Attribute-to-Delete: Machine Unlearning via Datamodel Matching

Introduction

Machine learning (ML) models are trained on data to identify patterns and make predictions. Typically, when a model is successfully trained on a dataset, the knowledge it gains from that data becomes deeply embedded in its parameters and can even be extracted given access to the model. However, a model developer may realize after the fact that training on some subset of the data was problematic, and want to remove the influence of those data points from the parameters of a trained model.
Consider the setting where a language model is trained on text data, and you want to modify the model to “unlearn” all the text from Harry Potter (perhaps for copyright reasons)—you’re tasked with producing a new model that never “read” any Harry Potter.
At this point, you may have two obvious first ideas:
1. You can retrain your model on the full dataset, now excluding Harry Potter. However, as models grow in size and training time, this can become prohibitively expensive (both in time and $$).
2. You may also try to corrupt your original model so that it performs poorly on Harry Potter; while this could make it forget Harry Potter, in practice the model will also likely “forget” lots of English, as much of Harry Potter is not specific to the wizarding universe.
Is there something cheaper (than retraining) and more principled we can do to unlearn? This challenge has motivated a recent line of work on (approximate) machine unlearning, where the goal is to remove (or “unlearn”) the impact of a specific collection of training examples, called the “forget set” $S_F$, from a machine learning model trained on a dataset $S$, in a computationally efficient fashion.
Unlearning methods try to find a shortcut from the original model to the retrained model.
The typical notion of success for unlearning is that the unlearned model is indistinguishable from a model fully retrained without the forget set (which we refer to as an “oracle”). Since directly comparing overparameterized models is extremely difficult, in practice we aim for indistinguishability of model outputs. This can be measured by retraining models on the dataset excluding the forget set (called the “retain set”, $S_R$) and then measuring the distance between their predictions and the predictions of unlearned models. The choice of distance metric is both very important and quite subtle; we propose a new metric called KL divergence Of Margins (KLoM).
A high-level description of how unlearning evaluation generally looks (using a similar framing to the NeurIPS 2023 Unlearning challenge).
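To make the metric concrete, here is a minimal sketch of a KLoM-style score for a single example. It assumes the margin is the correct-class logit minus the log-sum-exp of the other logits (the margin used in the datamodels line of work), that we have margins from several retrained oracle models and several unlearned models, and, purely as a simplification for this sketch, that each set of margins is summarized by a Gaussian fit. The paper's exact estimator, and the direction of the KL divergence, may differ; `margin`, `gaussian_kl`, and `klom_gaussian` are hypothetical names.

```python
import numpy as np

def margin(logits: np.ndarray, label: int) -> float:
    """Correct-class logit minus the log-sum-exp of all other logits."""
    others = np.delete(logits, label)
    top = others.max()  # subtract the max for a numerically stable log-sum-exp
    return float(logits[label] - (top + np.log(np.exp(others - top).sum())))

def gaussian_kl(mu_p: float, var_p: float, mu_q: float, var_q: float) -> float:
    """KL(P || Q) for univariate Gaussians P = N(mu_p, var_p), Q = N(mu_q, var_q)."""
    return float(0.5 * (np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0))

def klom_gaussian(oracle_margins, unlearned_margins, eps: float = 1e-6) -> float:
    """KL divergence between Gaussian fits of the oracle and unlearned margin
    distributions for one example; lower is better, 0 means identical fits.
    (The Gaussian fit is an assumption of this sketch.)"""
    p = np.asarray(oracle_margins, dtype=float)
    q = np.asarray(unlearned_margins, dtype=float)
    return gaussian_kl(p.mean(), p.var() + eps, q.mean(), q.var() + eps)

# Usage with synthetic margins for one example (illustrative numbers only).
rng = np.random.default_rng(0)
oracle = rng.normal(1.0, 0.5, size=100)  # margins under 100 retrained oracle models
good = rng.normal(1.0, 0.5, size=100)    # unlearned models that behave like the oracle
bad = rng.normal(4.0, 0.5, size=100)     # e.g., models that still "remember" the example
print(klom_gaussian(oracle, good), klom_gaussian(oracle, bad))
```

In practice a score like this is computed for every example in the forget, retain, and validation sets and then aggregated, for instance by taking a high percentile.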

The “Missing Targets” Problem

Failure Modes for Gradient-Based Unlearning
The majority of existing unlearning algorithms start from the trained model and unlearn by fine-tuning it. Specifically, they use some combination of:
  1. Gradient Ascent (GA) on the forget set, in order to undo the impact of the points we want to forget.
  1. Gradient Descent (GD) on the retain set, in order to reinforce the impact of the points that remain.
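As a minimal sketch of this recipe, here is one combined GA/GD update for a binary logistic-regression model, assuming mean cross-entropy loss. It illustrates the general idea rather than any specific published method; `ga_gd_step`, the step size `lr`, and the forget-set weight `alpha` are hypothetical.

```python
import numpy as np

def logistic_loss_grad(w: np.ndarray, X: np.ndarray, y: np.ndarray):
    """Mean binary cross-entropy (labels in {0, 1}) and its gradient w.r.t. w."""
    p = 1.0 / (1.0 + np.exp(-(X @ w)))
    eps = 1e-12  # guards against log(0)
    loss = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
    grad = X.T @ (p - y) / len(y)
    return loss, grad

def ga_gd_step(w, X_forget, y_forget, X_retain, y_retain, lr=0.1, alpha=1.0):
    """One unlearning update: ascend the forget-set loss, descend the retain-set loss."""
    _, g_forget = logistic_loss_grad(w, X_forget, y_forget)
    _, g_retain = logistic_loss_grad(w, X_retain, y_retain)
    return w - lr * g_retain + lr * alpha * g_forget

# Usage on synthetic data, starting from a stand-in for the trained weights.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X @ rng.normal(size=5) > 0).astype(float)
forget, retain = np.arange(20), np.arange(20, 200)
w0 = rng.normal(size=5)
w = w0.copy()
for _ in range(10):
    w = ga_gd_step(w, X[forget], y[forget], X[retain], y[retain])
```

Note that nothing in this update tells the optimizer when a forget point has been unlearned "enough": the ascent term can keep pushing the forget-set loss up indefinitely.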
However, we find (as does other recent work) that this general approach comes with a significant set of drawbacks, which we collectively refer to as the missing targets problem.
  1. First, both gradient-based methods assume that forget set points will increase in loss after unlearning and retain set points will not; this assumption does not always hold in practice. For example, if there are similar points in the forget and retain sets, excluding the forget set may increase the model’s loss on those similar retain points; conversely, the model may perform just as well on the forget set after unlearning if it can generalize sufficiently well from similar points in the retain set.
    1. For example, the Harry Potter sentence “He made several important telephone calls and shouted a bit more” is very similar to text one might find in “The Hitchhiker’s Guide to the Galaxy” or a legal case; thus, even after Harry Potter is removed from the dataset, a perfect unlearner may still be able to produce passages like this, having learned them from other texts.
  1. Second, even for a forget set point whose loss does increase, a perfect unlearner would not increase loss arbitrarily, but instead only until it reaches the expected loss under a perfectly retrained model — its “target value.”
Since we lack access to these target values, it is hard to know when a given forget set point has been “unlearned”, so many existing heuristic schemes overshoot or undershoot the target loss for a given data point. Compounding this problem, different points may reach their target values at different times over the course of the unlearning algorithm, so no single stopping time achieves good unlearning performance across all points.
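The toy example below (a linear least-squares model, not SCRUB or a neural network) illustrates both issues. It "cheats" by fitting the oracle on the retain set so that each forget point has a known target prediction, then runs GA on the forget set plus GD on the retain set from the full-data model, recording each forget point's distance to its target at every step. Sizes, step size, and step count are arbitrary choices for illustration; comparing the per-point best stopping times shows whether any single stopping time could serve every point.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 300, 10
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + 0.5 * rng.normal(size=n)
forget, retain = np.arange(30), np.arange(30, n)

def fit(idx):
    """Least-squares 'training algorithm' on the rows in idx."""
    return np.linalg.lstsq(X[idx], y[idx], rcond=None)[0]

def mse_grad(w, idx):
    """Gradient of 0.5 * mean squared error on the rows in idx."""
    return X[idx].T @ (X[idx] @ w - y[idx]) / len(idx)

w_full = fit(np.arange(n))
target = X[forget] @ fit(retain)   # oracle predictions: the per-point "targets" (cheating)

w, lr, steps = w_full.copy(), 0.05, 200
gap = np.empty((steps + 1, len(forget)))  # |prediction - target| for each forget point
gap[0] = np.abs(X[forget] @ w - target)
for t in range(1, steps + 1):
    w = w - lr * (mse_grad(w, retain) - mse_grad(w, forget))  # GD on retain, GA on forget
    gap[t] = np.abs(X[forget] @ w - target)

best_step = gap.argmin(axis=0)     # best stopping time for each forget point
print("per-point best stopping times:", sorted(set(best_step.tolist())))
```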
The figure below illustrates this phenomenon for a popular unlearning algorithm called SCRUB. Over iterations of the algorithm, different points are unlearned (and then subsequently “overshot”) at different points in time (Hayes et al., 2024).
Each line represents the unlearning quality for an individual data point as a function of the number of SCRUB iterations. Data points whose unlearning quality worsens over time are highlighted in red; this is a particular issue for forget points.

Thought experiment: How well can we unlearn if we “cheat”?

In light of these challenges, let’s take a step back and ask how well the best (possibly impractical) gradient-based unlearning algorithm could do. Since the underlying issue above is that we don’t know how the oracle model behaves, consider a setting where we have access to the predictions of a fully retrained model: in this case, can we efficiently fine-tune the original model to match the retrained model’s outputs?
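To study this, we propose the Oracle Matching (OM) algorithm, where we assume sample access (the ability to query the model and receive predictions) to the oracle $\theta^{\text{oracle}} = \mathcal{A}(S_R)$, where $\mathcal{A}$ is the training algorithm and $S_R$ is the retain set. We run stochastic gradient descent (SGD) using outputs (in this case, logits) from the oracle evaluated on some fine-tuning set $S_{\text{FT}}$, minimizing the following loss:
$$\mathcal{L}_{\text{OM}}(\theta) = \frac{1}{|S_{\text{FT}}|} \sum_{x \in S_{\text{FT}}} \big\| f_x(\theta) - f_x(\theta^{\text{oracle}}) \big\|_2^2,$$
where $f_x(\theta)$ denotes the logits of model $\theta$ on input $x$.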
Note that it is not obvious that this method should succeed: the unlearned model might be far enough from our initial model that we need to take as many gradient steps as retraining from scratch, or optimizing this objective might overfit to the points sampled from the oracle that we fine-tune on.
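Here is a minimal sketch of OM in a setting where the result is easy to check: a linear model whose logits are `X @ W`, with a hypothetical training algorithm (ridge regression onto one-hot labels) standing in for the real one. OM only queries the oracle's logits on the fine-tuning set (the forget set plus a small random subset of the retain set) and runs minibatch SGD on the squared error between the model's logits and those oracle logits. Data sizes, the subset size, and the SGD hyperparameters are assumptions for illustration, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 500, 10, 3
X = rng.normal(size=(n, d))
Y = np.eye(k)[rng.integers(0, k, size=n)]      # one-hot labels
forget, retain = np.arange(50), np.arange(50, n)

def train(idx, lam=1e-2):
    """Hypothetical training algorithm: ridge regression onto one-hot labels."""
    Xs = X[idx]
    return np.linalg.solve(Xs.T @ Xs + lam * np.eye(d), Xs.T @ Y[idx])

W_orig = train(np.arange(n))     # model trained on the full dataset
W_oracle = train(retain)         # oracle; OM only sees its logits on the fine-tuning set

# Fine-tuning set: the forget set plus a small random subset of the retain set.
ft = np.concatenate([forget, rng.choice(retain, size=50, replace=False)])
oracle_logits = X[ft] @ W_oracle

def oracle_matching(W, steps=2000, lr=0.05, batch=32):
    """Minibatch SGD on 0.5 * mean squared error between model and oracle logits."""
    W = W.copy()
    for _ in range(steps):
        b = rng.choice(len(ft), size=batch, replace=False)
        resid = X[ft[b]] @ W - oracle_logits[b]  # (batch, k) logit differences
        W -= lr * X[ft[b]].T @ resid / batch
    return W

W_om = oracle_matching(W_orig)
# Evaluation only: compare logits on all points, including ones never fine-tuned on.
gap_before = np.abs(X @ W_orig - X @ W_oracle).max()
gap_after = np.abs(X @ W_om - X @ W_oracle).max()
print(f"max logit gap: before {gap_before:.3e}, after {gap_after:.3e}")
```

In this linear, realizable toy, matching logits on enough fine-tuning points recovers the oracle everywhere; for neural networks there is no such guarantee, which is why the empirical results matter.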
Our first discovery is that OM is extremely good at approximating the distribution of oracle predictions (as measured by KLoM scores, where lower is better), using only a small fraction (< 5%) of the compute of retraining.
 
This plot shows the KLoM scores of different methods as a function of computational cost (rescaled by the cost of full retraining). We plot the 99th percentile of KLoM scores on the forget, retain, and validation sets, as well as their average. Observe that most methods perform well (low KLoM) on retain and validation points, but only Oracle Matching and retraining perform well on the forget set.
One potential reason is that OM does not exhibit the reversal of unlearning quality over time that we saw with SCRUB:
 
Unlike in SCRUB, in OM points that are unlearned generally remain unlearned, apart from a few points (red lines) in the retain and validation sets. This allows OM to perform well on most points as long as we stop the algorithm after sufficiently many iterations.
A natural question given this framework is how to choose the fine-tuning set in order to effectively “distill” the oracle model. We find that it suffices to fine-tune on the forget set (as expected, since these are the points where the model’s behavior generally changes the most) and on a small fraction of the retain set (just … of the retain set generally suffices). This makes OM highly efficient, as we don’t need to fine-tune on the entire training set.
It turns out that in simple linear settings we can characterize precisely why OM is more effective than standard gradient descent/ascent variants. At a high level, the convergence rate benefits significantly from realizability: since every minibatch in OM has labels generated by the same underlying (oracle) model, progress on one minibatch is not reversed by a subsequent update. We explore this further in Section 5 of the paper.
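The role of realizability shows up in a toy experiment (an illustration only, not the analysis from Section 5 of the paper): constant-step minibatch SGD on linear least squares. When every label comes from one linear model, the same weights minimize every minibatch, so SGD converges to them; with label noise, different minibatches pull toward different solutions and SGD stalls at a noise floor. The function name, sizes, and hyperparameters are hypothetical.

```python
import numpy as np

def sgd_distance_to_optimum(noise: float, steps=3000, lr=0.05, batch=8, seed=0) -> float:
    """Run constant-step minibatch SGD on least squares and return the squared
    distance between the final weights and the full-data least-squares solution."""
    rng = np.random.default_rng(seed)
    n, d = 200, 5
    X = rng.normal(size=(n, d))
    y = X @ rng.normal(size=d) + noise * rng.normal(size=n)  # noise=0 -> realizable labels
    w_opt = np.linalg.lstsq(X, y, rcond=None)[0]
    w = np.zeros(d)
    for _ in range(steps):
        b = rng.choice(n, size=batch, replace=False)
        w -= lr * X[b].T @ (X[b] @ w - y[b]) / batch
    return float(np.sum((w - w_opt) ** 2))

realizable = sgd_distance_to_optimum(noise=0.0)      # like OM: one model generates all labels
non_realizable = sgd_distance_to_optimum(noise=1.0)  # no weights fit every label exactly
print(realizable, non_realizable)
```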

Approximating Oracle Matching with Datamodels

We saw that, given access to predictions from a retrained model, we can easily fine-tune the original model to match it using OM. But since our entire aim is to compute this retrained model efficiently, it feels like we’re no closer to effective unlearning than when we started! What if, instead, we could efficiently estimate the retrained model’s predictions themselves, and then plug these estimates into OM? For this approach to be feasible, estimating the predictions of a retrained model would have to be much simpler than estimating the retrained model itself.
Conveniently, producing accurate estimates of this kind is precisely what a recent line of work on predictive data attribution studies; in particular, we take the formalization from Ilyas et al. (2022), which we summarize briefly below.
In predictive data attribution, the goal is to produce an estimator (or datamodel) that takes as input a training set $S' \subseteq S$ (a subset of the full training set $S$) and a target example $x$, and outputs a prediction of the effect of training on $S'$ and then predicting on input $x$. Using our existing notation, a datamodel for $x$ is a function $\hat{f}_x$ such that, for any $S' \subseteq S$ and the training procedure $\mathcal{A}$,
$$\hat{f}_x(S') \approx f_x(\mathcal{A}(S')),$$
where $f_x(\theta)$ is the output of model $\theta$ on input $x$.
In short, a datamodel predicts how a model would behave had it been trained on any given subset of the dataset. It can be thought of as a counterfactual predictor: what would the model have predicted if it had been trained only on a particular subset of the data?
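As a toy sketch, the code below fits a linear datamodel for one target input in the spirit of Ilyas et al. (2022): sample many random training subsets, "train" on each, record the output on the target input, and regress those outputs on the subset-membership indicator vectors. Several choices are assumptions: ridge regression stands in for the training algorithm, plain least squares replaces the sparse (ℓ1-regularized) regression used in the original datamodels work, and the inclusion probability `alpha`, the number of subsets `m`, and all helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 100, 5
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + rng.normal(size=n)
x_target = rng.normal(size=d)            # the target input x whose output we model

def train_and_predict(mask: np.ndarray, lam: float = 1.0) -> float:
    """Hypothetical A(S') followed by f_x: ridge regression on the rows selected
    by the boolean mask, then the prediction on x_target."""
    Xs, ys = X[mask], y[mask]
    w = np.linalg.solve(Xs.T @ Xs + lam * np.eye(d), Xs.T @ ys)
    return float(x_target @ w)

# 1) Sample random subsets S' (each point kept with probability alpha) and record outputs.
#    alpha is near 1 because unlearning queries remove only a few points (an assumption).
alpha, m = 0.9, 3000
masks = rng.random((m, n)) < alpha
outputs = np.array([train_and_predict(mk) for mk in masks])

# 2) Fit a linear datamodel, outputs ≈ masks @ theta + bias, by ordinary least squares.
design = np.hstack([masks.astype(float), np.ones((m, 1))])
coef = np.linalg.lstsq(design, outputs, rcond=None)[0]
theta, bias = coef[:-1], coef[-1]

def datamodel(mask: np.ndarray) -> float:
    """Predicted output of a model trained on the subset encoded by mask."""
    return float(mask.astype(float) @ theta + bias)

# 3) Counterfactual query: the output had we trained without the first 10 points.
retain_mask = np.ones(n, dtype=bool)
retain_mask[:10] = False
print("datamodel estimate:     ", datamodel(retain_mask))
print("actual retrained output:", train_and_predict(retain_mask))
```

Once fitted, each counterfactual query costs a single dot product, which is what makes plugging datamodel predictions into OM cheap.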

Unlearning via Datamodel Matching

In light of the above discussion, we propose the following meta-algorithm for unlearning: Datamodel Matching.
Datamodel Matching (DMM): given a trained model, we (a) use datamodels to predict the model’s outputs on a collection of points if it were retrained without the forget set, then (b) fine-tune the full model to match these predicted outputs, as in OM.
 
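Algorithmically, given datamodels $\hat{f}_x$ for the points $x$ in the fine-tuning set $S_{\text{FT}}$, we run SGD to minimize the OM loss with the oracle’s outputs replaced by the datamodel predictions for the retain set $S_R$ (here $f_x(\theta)$ is the output of the model being fine-tuned on $x$):
$$\mathcal{L}_{\text{DMM}}(\theta) = \frac{1}{|S_{\text{FT}}|} \sum_{x \in S_{\text{FT}}} \big\| f_x(\theta) - \hat{f}_x(S_R) \big\|_2^2.$$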
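Putting the two steps together, here is a toy end-to-end sketch of DMM for a linear multiclass model whose logits are `X @ W`. Step (a) fits one linear datamodel per (fine-tuning example, class logit) pair by least squares on random subset masks, then queries it at the retain-set mask to estimate the oracle logits; step (b) fine-tunes the original model to match those estimates. The ridge-regression training algorithm, the subset sampling, the fine-tuning set size, and the use of full-batch gradient descent instead of SGD are simplifying assumptions of this sketch, not the paper's setup, which estimates datamodels from many trained neural networks.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 200, 5, 3
X = rng.normal(size=(n, d))
Y = np.eye(k)[rng.integers(0, k, size=n)]            # one-hot labels
forget = np.arange(20)
retain_mask = np.ones(n, dtype=bool)
retain_mask[forget] = False
# Fine-tuning set: forget set plus a random subset of retain points (hypothetical size).
ft = np.concatenate([forget, rng.choice(np.arange(20, n), size=20, replace=False)])

def train(mask, lam=1.0):
    """Hypothetical training algorithm: ridge regression onto one-hot labels."""
    Xs = X[mask]
    return np.linalg.solve(Xs.T @ Xs + lam * np.eye(d), Xs.T @ Y[mask])

# (a) Linear datamodels for every (fine-tuning example, class logit) pair.
alpha, m = 0.9, 2000
masks = rng.random((m, n)) < alpha
logits = np.stack([(X[ft] @ train(mk)).ravel() for mk in masks])  # (m, |ft| * k)
design = np.hstack([masks.astype(float), np.ones((m, 1))])
coef = np.linalg.lstsq(design, logits, rcond=None)[0]             # (n + 1, |ft| * k)
retain_row = np.append(retain_mask.astype(float), 1.0)
predicted = (retain_row @ coef).reshape(len(ft), k)               # estimated oracle logits

# (b) Fine-tune the original model to match the predicted logits, as in OM.
def dmm_loss(W):
    return 0.5 * np.mean(np.sum((X[ft] @ W - predicted) ** 2, axis=1))

W_full = train(np.ones(n, dtype=bool))               # original model trained on all data
W = W_full.copy()
loss_before = dmm_loss(W)
lr = 0.1
for _ in range(500):
    resid = X[ft] @ W - predicted
    W = W - lr * X[ft].T @ resid / len(ft)           # full-batch gradient step
loss_after = dmm_loss(W)

# Evaluation only: DMM never queries the true oracle.
W_oracle = train(retain_mask)
print("max logit gap before:", np.abs(X @ W_full - X @ W_oracle).max())
print("max logit gap after: ", np.abs(X @ W - X @ W_oracle).max())
```

How close the final gap gets depends on how accurate the datamodel estimates are; with perfect estimates this reduces to oracle matching.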
 
After training sufficiently accurate datamodels (estimated using 20K models) and substituting their predictions into the OM procedure: voilà! 🪄🪄
 
Here we show KLoM scores for different methods, plotted against computational cost (relative to the cost of fully retraining the model, i.e., the perfect unlearner). The figure also shows DM-Direct, a method we introduce in Section 4 of the full paper.
 
Comparing against prior methods and baselines (including retraining), we observe a few trends:
  1. DMM performs much better than all existing baselines (KLoM score is lower) and approaches the unlearning quality of the “ground-truth” retraining.
  1. At the same time, DMM is much cheaper than full-retraining (< 5% of training compute!)
  1. Many existing baselines can hurt unlearning quality and do worse than doing nothing: while they sometimes unlearn the forget set, they also significantly perturb predictions on the retain set or hurt validation accuracy.

Next Steps

We saw that viewing the problem of unlearning through the lens of predictive data attribution gives us a promising new class of approximate unlearning algorithms. But further research is needed to make these methods computationally efficient and more broadly applicable. To highlight some directions:
  1. Extending techniques beyond classification: The proposed methods work well on smaller-scale classification tasks, but extending them to larger-scale settings (such as ImageNet or language modeling) requires more work.
  1. Improving oracle matching: While OM is a simple and provably effective algorithm (as we show in simple linear settings), a better understanding of its training dynamics, along with alternative (e.g., non-random) sampling strategies, could improve it further.
  1. Reducing computational costs: Although a lot of progress has already been made in improving the effectiveness and efficiency of data attribution, the meta-algorithms presented here would benefit directly from further improvements that scale attribution to larger models and datasets.
 
Interested? Reach out!
If you’re interested in this work, take a deeper look at our paper, and feel free to reach out to the authors. If you found this blog post useful for your own work, please cite our paper!
Correspondence: royrinberg@g.harvard.edu, sp765@mit.edu
 
See this blog post for a recent survey on machine unlearning.
Initially spurred by regulations such as the EU’s Right to be Forgotten, machine unlearning has found a variety of recent goals including: removing the effect of toxic, outdated, or poisoned data; removing concepts; rectifying copyright infringement in generative models; and even LLM safety training.
See Section 3 of the paper for details.
SCRUB runs a GD-based algorithm on the retain set for a moderate number of epochs, while also running a GA-based algorithm on the forget set for the first few epochs.
 
 

Recommendations