dp_utils.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import math
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, List, Optional, Union, Dict, Sequence

import pandas as pd
import datasets
from datasets import Dataset
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import (
    Trainer,
    TrainerCallback,
    TrainerState,
    TrainerControl,
    logging,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizer,
    training_args,
    modeling_utils,
    get_cosine_schedule_with_warmup,
)
import transformers
from transformers.file_utils import is_sagemaker_mp_enabled, is_datasets_available
import opacus
from opacus.accountants import RDPAccountant
from prv_accountant import Accountant as PRVAccountant

from dp_transformers import sampler, arguments

logger = logging.get_logger(__name__)
class DPCallback(TrainerCallback):
    """
    This class registers all the necessary callbacks to make transformers.Trainer
    compatible with Opacus.
    """

    def __init__(
        self,
        noise_multiplier: float,
        target_delta: float,
        sampling_probability: float,
        rdp_accountant: RDPAccountant,
        prv_accountant: PRVAccountant,
        max_epsilon: float = float("inf"),
    ) -> None:
        self.noise_multiplier = noise_multiplier
        self.target_delta = target_delta
        self.sampling_probability = sampling_probability
        self.rdp_accountant = rdp_accountant
        self.prv_accountant = prv_accountant
        self.max_epsilon = max_epsilon

        self.on_substep_end_was_called = False
        self.compute_rdp_epsilon = lambda: self.rdp_accountant.get_epsilon(self.target_delta)
        self.compute_prv_epsilon = lambda s: self.prv_accountant.compute_epsilon(s)[2]

    def on_substep_end(
        self,
        args: training_args.TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        optimizer=None,
        **kwargs,
    ):
        if optimizer is None:
            raise RuntimeError("Impossible to access optimizer from inside callback")
        # On gradient-accumulation substeps, take a "skipped" optimizer step: Opacus
        # accumulates the per-sample gradients but performs no parameter update.
        optimizer.signal_skip_step(do_skip=True)
        optimizer.step()
        optimizer.zero_grad()

        self.on_substep_end_was_called = True

    def on_step_end(
        self,
        args: training_args.TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        optimizer=None,
        **kwargs,
    ):
        if not (args.gradient_accumulation_steps <= 1 or self.on_substep_end_was_called):
            raise RuntimeError(
                "Gradient accumulation was specified but `on_substep_end` wasn't called. "
                "Make sure you're using a recent version of transformers (>=4.10.0) "
                "which has an appropriate callback in the trainer."
            )
        if optimizer is None:
            raise RuntimeError("Impossible to access optimizer from inside callback")
        optimizer.zero_grad()  # Opacus expects an explicit .zero_grad() on the optimizer, which HF's Trainer does not issue here
        self.rdp_accountant.step(
            noise_multiplier=self.noise_multiplier,
            sample_rate=self.sampling_probability,
        )

    def on_save(
        self,
        args: training_args.TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        return self._check_max_epsilon_exceeded(state, control)

    def on_evaluate(
        self,
        args: training_args.TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        return self._check_max_epsilon_exceeded(state, control)

    def _check_max_epsilon_exceeded(
        self, state: TrainerState, control: TrainerControl
    ) -> TrainerControl:
        eps_rdp = self.compute_rdp_epsilon()
        eps_prv = self.compute_prv_epsilon(state.global_step)
        if eps_rdp > self.max_epsilon or eps_prv > self.max_epsilon:
            logger.error("Max epsilon exceeded. Stopping training...")
            control.should_training_stop = True
        return control
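# A minimal sketch of the accounting the callback performs once per optimizer step
# (the noise multiplier, sample rate, step count, and delta below are illustrative):
#
#   acc = RDPAccountant()
#   for _ in range(1000):
#       acc.step(noise_multiplier=1.0, sample_rate=0.004)
#   eps = acc.get_epsilon(delta=1e-6)  # privacy spent after 1000 steps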
class SaveModelConfigCallback(TrainerCallback):
    """
    Writes the model config into the checkpoint folder at the end of every epoch.
    """

    def __init__(self, model, output_dir):
        self.model = model
        self.output_dir = output_dir

    def on_epoch_end(self, args, state, control, **kwargs):
        folder = f"{self.output_dir}/checkpoint-{state.global_step}"
        Path(folder).mkdir(parents=True, exist_ok=True)
        self.model.config.to_json_file(f"{folder}/config.json")
class DataCollatorForPrivateCausalLanguageModeling(DataCollatorForLanguageModeling):
    def __init__(self, tokenizer: PreTrainedTokenizer):
        super().__init__(tokenizer=tokenizer, mlm=False)

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        batch = super().__call__(examples)

        # Huggingface's default way of constructing position_ids is not compatible with Opacus,
        # since Opacus cannot deduce the batch size from the input. Here we manually generate a
        # position_ids tensor with the same values as Huggingface's default tensor, but built in
        # an Opacus-compatible way by repeating the position range along the batch dimension.
        if "position_ids" not in batch:
            input_ids = batch["input_ids"]
            batch["position_ids"] = torch.arange(
                input_ids.shape[1], dtype=torch.long, device=input_ids.device
            ).repeat(input_ids.shape[0], 1)
        return batch
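# A minimal usage sketch for the collator (the tokenizer name and token ids are
# illustrative assumptions, not part of this module):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   tok.pad_token = tok.eos_token
#   collator = DataCollatorForPrivateCausalLanguageModeling(tok)
#   batch = collator([{"input_ids": [31373, 995]}, {"input_ids": [31373, 995]}])
#   # batch["position_ids"] == tensor([[0, 1], [0, 1]]): one row per sample, so
#   # Opacus can read the batch size off the leading dimension.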
class GradSampleModule(opacus.GradSampleModule):
    """
    Thin wrapper providing the `no_sync` context manager that the Huggingface Trainer
    assumes; nothing else needs to happen here.
    """

    @contextmanager
    def no_sync(self):
        yield
def create_author_mapping(dataset: Dataset, author: str) -> Sequence[Sequence[int]]:
    """
    Creates a mapping from authors to the indices of their samples in a dataset.
    """
    with dataset.formatted_as(type="pandas"):
        authors = pd.DataFrame(data={"author": dataset[author]})
        author_mapping = [g.index.values for _, g in authors.groupby("author")]
    return author_mapping
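# A minimal sketch of the mapping this produces (the toy dataset is illustrative):
#
#   ds = Dataset.from_dict({"author": ["alice", "bob", "alice"], "text": ["x", "y", "z"]})
#   create_author_mapping(ds, author="author")
#   # -> [array([0, 2]), array([1])]: rows 0 and 2 belong to "alice", row 1 to "bob".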
def get_optimizer(model, args: arguments.TrainingArguments, lr: float = 2e-4):
    # NOTE: these constants are hard-coded for a specific dataset and hardware setup
    # (number of training rows, number of GPUs) and should be adapted accordingly.
    NROWS = 2173762
    NGPU = 1
    optim = transformers.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=lr
    )
    # Use true division here: `math.ceil` applied to floor division (`//`) is a no-op,
    # so a partial final batch would otherwise not be counted as a step.
    num_training_steps = (
        math.ceil(NROWS / (args.per_device_train_batch_size * 2))
        * NGPU
        * args.num_train_epochs
    )
    scheduler = transformers.get_cosine_schedule_with_warmup(
        optim, num_warmup_steps=0, num_training_steps=num_training_steps
    )
    return (optim, scheduler)
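# Worked example of the step arithmetic above (values illustrative): with
# NROWS = 2,173,762, per_device_train_batch_size = 32 and num_train_epochs = 3,
# num_training_steps = ceil(2173762 / 64) * 1 * 3 = 33966 * 3 = 101898.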
class OpacusDPTrainer(Trainer):
    """
    Wrapper to modify the Huggingface Trainer to:
    (i) remove the "loss = loss / self.args.gradient_accumulation_steps" operation in
        training_step, as this is already handled by the Opacus package.
    (ii) enable author-level DP training by modifying the sampler and the dataloader. In the
        case of sample-level DP, each sample can be represented by a unique author.
    (iii) wrap the optimizer with Opacus' DPOptimizer/DistributedDPOptimizer.
    """

    def __init__(
        self,
        model: Union[
            modeling_utils.PreTrainedModel, torch.nn.modules.module.Module
        ] = None,
        args: arguments.TrainingArguments = None,
        train_dataset: Optional[torch.utils.data.dataset.Dataset] = None,
        privacy_args: arguments.PrivacyArguments = None,
        author_mapping: Optional[Sequence[Sequence[int]]] = None,
        tokenizer: PreTrainedTokenizer = None,
        **kwargs: Dict,
    ) -> None:
        self.train_args = args
        self.privacy_args = privacy_args

        # Sample-level DP is equivalent to mapping each sample to a unique author.
        if author_mapping is None:
            author_mapping = [[i] for i in range(len(train_dataset))]
        self.author_mapping = author_mapping

        if not self.privacy_args.is_initialized:
            self.privacy_args.initialize(
                sampling_probability=self.sampling_probability,
                num_steps=self.num_steps,
                num_samples=len(self.author_mapping),
            )

        # Wrap model in DPDDP (distributed) and GradSampleModule
        if args.parallel_mode == training_args.ParallelMode.DISTRIBUTED:
            logger.info("Wrapping the model with DPDDP in distributed training.")
            model = opacus.distributed.DifferentiallyPrivateDistributedDataParallel(
                model
            )

        # Add the save-config callback
        self.saveconfig_callback = SaveModelConfigCallback(model, args.output_dir)

        model = GradSampleModule(model)

        # Instantiate the privacy accountants
        self.rdp_accountant = RDPAccountant()
        self.prv_accountant = PRVAccountant(
            noise_multiplier=self.privacy_args.noise_multiplier,
            sampling_probability=self.sampling_probability,
            delta=self.privacy_args.target_delta,
            eps_error=0.1,
            max_compositions=self.num_steps,
        )

        # Set up the callback for accounting and for handling gradient accumulation
        self.dp_callback = DPCallback(
            noise_multiplier=self.privacy_args.noise_multiplier,
            target_delta=self.privacy_args.target_delta,
            sampling_probability=self.sampling_probability,
            rdp_accountant=self.rdp_accountant,
            prv_accountant=self.prv_accountant,
        )
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            callbacks=[self.dp_callback, self.saveconfig_callback],
            tokenizer=tokenizer,
            optimizers=get_optimizer(model, args),
            **kwargs,
        )

        self.get_rdp_epsilon = lambda: self.rdp_accountant.get_epsilon(
            self.privacy_args.target_delta
        )  # RDP epsilon
        self.get_prv_epsilon = lambda: self.prv_accountant.compute_epsilon(
            self.state.global_step
        )[2]

    @property
    def sampling_probability(self) -> float:
        return (
            self.train_args.per_device_train_batch_size
            * self.train_args.world_size
            * self.train_args.gradient_accumulation_steps
            / len(self.author_mapping)
        )

    @property
    def num_steps(self) -> int:
        return int(
            self.train_args.num_train_epochs * (1 / self.sampling_probability + 1)
        )
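    # Worked example (all values illustrative): with per_device_train_batch_size=8,
    # world_size=2, gradient_accumulation_steps=4 and 16,000 authors, the sampling
    # probability is 8 * 2 * 4 / 16000 = 0.004; one epoch then corresponds to about
    # 1 / 0.004 = 250 logical batches, and `num_steps` over 3 epochs is
    # int(3 * (250 + 1)) = 753, which is what the PRV accountant sizes itself for.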
    def create_optimizer(self):
        _ = super().create_optimizer()

        if self.args.parallel_mode == training_args.ParallelMode.DISTRIBUTED:
            optimizer_generator = opacus.optimizers.DistributedDPOptimizer
        else:
            optimizer_generator = opacus.optimizers.DPOptimizer

        self.optimizer = optimizer_generator(
            optimizer=self.optimizer,
            noise_multiplier=self.privacy_args.noise_multiplier,
            max_grad_norm=self.privacy_args.per_sample_max_grad_norm,
            expected_batch_size=self.args.per_device_train_batch_size
            * self.args.gradient_accumulation_steps,
        )

        return self.optimizer
    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect
                the targets under the argument :obj:`labels`. Check your model's documentation for
                all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            raise NotImplementedError("DP currently doesn't support this")

        if self.use_cuda_amp or self.use_cpu_amp:
            raise NotImplementedError("DP currently doesn't support this")
        else:
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        # Compared to the original HF implementation, we have to remove the loss scaling by the
        # number of gradient accumulation steps since Opacus scales the gradients accordingly.
        # However, we still need to scale the loss that is returned in order for the logging to
        # work correctly. Hence we scale the loss after the call to loss.backward().
        if self.use_cuda_amp or self.use_cpu_amp:
            raise NotImplementedError("DP currently doesn't support this")
        elif self.use_apex:
            raise NotImplementedError("DP currently doesn't support this")
        elif self.deepspeed:
            raise NotImplementedError("DP currently doesn't support this")
        else:
            loss.backward()

        return loss.detach() / self.args.gradient_accumulation_steps
    def _get_train_sampler(self):
        """
        Provides the author-level sampler.
        """
        train_sampler = sampler.ShuffledAuthorSampler(
            author_mapping=self.author_mapping,
            batch_size=self.args.per_device_train_batch_size,
            world_size=self.args.world_size,
        )
        return train_sampler

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use the author-level sampler from dp_transformers.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_sampler = self._get_train_sampler()

        train_dataset = self.train_dataset
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(
                train_dataset, description="training"
            )

        return DataLoader(
            train_dataset,
            batch_sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )
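# A minimal end-to-end sketch of how this module is wired together. The model and
# tokenizer names, the argument values, and the exact PrivacyArguments fields are
# illustrative assumptions; consult dp_transformers.arguments for the real
# definitions.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token
#
#   trainer = OpacusDPTrainer(
#       model=model,
#       args=train_args,              # arguments.TrainingArguments
#       privacy_args=privacy_args,    # arguments.PrivacyArguments (delta, noise, clipping)
#       train_dataset=train_dataset,  # a tokenized datasets.Dataset
#       tokenizer=tokenizer,
#       data_collator=DataCollatorForPrivateCausalLanguageModeling(tokenizer),
#   )
#   trainer.train()
#   print(trainer.get_prv_epsilon())  # privacy spent so far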