Welcome to torch-influence’s API Reference!
Base Modules
- class torch_influence.BaseInfluenceModule(model, objective, train_loader, test_loader, device)[source]
The core module that contains convenience methods for computing influence functions.
- Parameters
model (torch.nn.Module) – the model of interest.
objective (BaseObjective) – an implementation of BaseObjective.
train_loader (torch.utils.data.DataLoader) – a training dataset loader.
test_loader (torch.utils.data.DataLoader) – a test dataset loader.
device (torch.device) – the device on which operations are performed.
- abstract inverse_hvp(vec)[source]
Computes an inverse-Hessian vector product, where the Hessian is specifically that of the (mean) empirical risk over the training dataset.
- Parameters
vec (torch.Tensor) – a vector.
- Return type
torch.Tensor
- Returns
the inverse-Hessian vector product.
- train_loss_grad(train_idxs)[source]
Returns the gradient of the (mean) training loss over a set of training data points with respect to the model’s flattened parameters.
- Parameters
train_idxs (List[int]) – the indices of the training points.
- Return type
torch.Tensor
- Returns
the loss gradient at the training points.
- test_loss_grad(test_idxs)[source]
Returns the gradient of the (mean) test loss over a set of test data points with respect to the model’s flattened parameters.
- Parameters
test_idxs (List[int]) – the indices of the test points.
- Return type
torch.Tensor
- Returns
the loss gradient at the test points.
- stest(test_idxs)[source]
This function simply composes inverse_hvp() with test_loss_grad().
In the original influence function paper, the resulting vector was called \(\mathbf{s}_{\mathrm{test}}\).
- Parameters
test_idxs (List[int]) – the indices of the test points.
- Return type
torch.Tensor
- Returns
the \(\mathbf{s}_{\mathrm{test}}\) vector.
- influences(train_idxs, test_idxs, stest=None)[source]
Returns the influence scores of a set of training data points with respect to the (mean) test loss over a set of test data points.
Specifically, this method returns a 1D tensor of len(train_idxs) influence scores. These scores estimate the following quantities:
Let \(\mathcal{L}_0\) be the (mean) test loss of the current model over the input test points. Suppose we produce a new model by (1) removing the train_idxs[i]-th example from the training dataset and (2) retraining the model on this one-smaller dataset. Let \(\mathcal{L}\) be the (mean) test loss of the new model over the input test points. Then the i-th influence score estimates \(\mathcal{L} - \mathcal{L}_0\).
- Parameters
train_idxs (List[int]) – the indices of the training points.
test_idxs (List[int]) – the indices of the test points.
stest (Optional[torch.Tensor]) – this method requires the \(\mathbf{s}_{\mathrm{test}}\) vector of the input test points. If not None, this argument is used as \(\mathbf{s}_{\mathrm{test}}\). Otherwise, \(\mathbf{s}_{\mathrm{test}}\) is computed internally with stest().
- Return type
torch.Tensor
- Returns
the influence scores.
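As a rough illustration of what these scores estimate, the sketch below works through a one-parameter least-squares model in plain Python, without using torch-influence. The model, the data, and the helper names (fit, loss_on_test_point, influence_scores, retrained_change) are assumptions made for this example. It forms \(\mathbf{s}_{\mathrm{test}}\) once, reuses it for every requested training point, and compares one score against actual leave-one-out retraining. The first-order formula follows Koh & Liang (2017); it is not a transcription of the library's implementation.

```python
# Toy model (an assumption for illustration): prediction = theta * x,
# per-example loss = 0.5 * (theta * x - y) ** 2, no regularization.
XS = [1.0, 2.0, 3.0, 4.0, 5.0]
YS = [1.1, 1.9, 3.2, 3.9, 5.3]
X_TEST, Y_TEST = 2.0, 2.5


def fit(xs, ys):
    """Closed-form minimizer of the mean training loss."""
    return sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)


def loss_on_test_point(theta):
    return 0.5 * (theta * X_TEST - Y_TEST) ** 2


def influence_scores(train_idxs):
    """First-order estimates of (test loss after removing point i) - (current test loss)."""
    n = len(XS)
    theta = fit(XS, YS)
    hessian = sum(x * x for x in XS) / n            # Hessian of the mean training loss
    test_grad = (theta * X_TEST - Y_TEST) * X_TEST  # gradient of the test loss
    s_test = test_grad / hessian                    # inverse-Hessian vector product
    scores = []
    for i in train_idxs:
        train_grad = (theta * XS[i] - YS[i]) * XS[i]
        scores.append(train_grad * s_test / n)
    return scores


def retrained_change(i):
    """Ground truth: drop point i, refit, and measure the change in test loss."""
    xs = XS[:i] + XS[i + 1:]
    ys = YS[:i] + YS[i + 1:]
    return loss_on_test_point(fit(xs, ys)) - loss_on_test_point(fit(XS, YS))


estimate = influence_scores([0])[0]
actual = retrained_change(0)
print(f"estimated change: {estimate:.6f}, retrained change: {actual:.6f}")
```

The estimate tracks retraining closely for this low-leverage point. For a point that dominates the Hessian (such as x = 5 in this toy dataset), the first-order approximation is noticeably coarser.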
- class torch_influence.BaseObjective[source]
An abstract adapter that provides torch-influence with project-specific information about how training and test objectives are computed.
In order to use torch-influence in your project, a subclass of this module should be created that implements this module’s four abstract methods.
- abstract train_outputs(model, batch)[source]
Returns a batch of model outputs (e.g., logits, probabilities) from a batch of data.
- Parameters
model (torch.nn.Module) – the model.
batch (Any) – a batch of training data.
- Return type
torch.Tensor
- Returns
the model outputs produced from the batch.
- abstract train_loss_on_outputs(outputs, batch)[source]
Returns the mean-reduced loss of the model outputs produced from a batch of data.
- Parameters
outputs (torch.Tensor) – a batch of model outputs.
batch (Any) – a batch of training data.
- Return type
torch.Tensor
- Returns
the loss of the outputs over the batch.
Note
There may be some ambiguity in how to define train_outputs() and train_loss_on_outputs(): at what point in the forward pass do the outputs end and the loss function begin? For example, in binary classification, the outputs can reasonably be taken to be the model logits or normalized probabilities.
For standard use of influence functions, both choices produce the same behaviour. However, if using the Gauss-Newton Hessian approximation for influence functions, we require that train_loss_on_outputs() be convex in the model outputs.
See also the gnh parameter of CGInfluenceModule and LiSSAInfluenceModule.
- abstract train_regularization(params)[source]
Returns the regularization loss at a set of model parameters.
- Parameters
params (torch.Tensor) – a flattened vector of model parameters.
- Return type
torch.Tensor
- Returns
the regularization loss.
- train_loss(model, params, batch)[source]
Returns the mean-reduced regularized loss of a model over a batch of data.
This method should not be overridden for most use cases. By default, torch-influence takes and expects the overall training loss to be:
outputs = train_outputs(model, batch)
loss = train_loss_on_outputs(outputs, batch) + train_regularization(params)
- Parameters
model (torch.nn.Module) – the model.
params (torch.Tensor) – a flattened vector of the model’s parameters.
batch (Any) – a batch of training data.
- Return type
torch.Tensor
- Returns
the training loss over the batch.
- abstract test_loss(model, params, batch)[source]
Returns the mean-reduced loss of a model over a batch of data.
- Parameters
model (torch.nn.Module) – the model.
params (torch.Tensor) – a flattened vector of the model’s parameters.
batch (Any) – a batch of test data.
- Return type
torch.Tensor
- Returns
the test loss over the batch.
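To make the division of labour between these methods concrete, here is a minimal sketch that mirrors their signatures with plain Python types. It is deliberately not a real BaseObjective subclass: lists of floats stand in for torch.Tensor, a callable stands in for torch.nn.Module, and the class name ToyObjective, the squared-error loss, and the weight_decay argument are all assumptions made for illustration. A real implementation would subclass BaseObjective and operate on tensors.

```python
class ToyObjective:
    """Plain-Python stand-in that mirrors the BaseObjective method layout."""

    def __init__(self, weight_decay):
        self.weight_decay = weight_decay

    def train_outputs(self, model, batch):
        # Forward pass only: no loss is computed here.
        xs, _ = batch
        return [model(x) for x in xs]

    def train_loss_on_outputs(self, outputs, batch):
        # Mean-reduced loss as a function of the outputs (squared error here).
        _, ys = batch
        return sum((o - y) ** 2 for o, y in zip(outputs, ys)) / len(ys)

    def train_regularization(self, params):
        # L2 penalty on the flattened parameters.
        return 0.5 * self.weight_decay * sum(p * p for p in params)

    def train_loss(self, model, params, batch):
        # Same composition as the documented default.
        outputs = self.train_outputs(model, batch)
        return self.train_loss_on_outputs(outputs, batch) + self.train_regularization(params)

    def test_loss(self, model, params, batch):
        # Unregularized loss on a batch of test data.
        outputs = self.train_outputs(model, batch)
        return self.train_loss_on_outputs(outputs, batch)


theta = 2.0
model = lambda x: theta * x            # one-parameter linear model
batch = ([1.0, 2.0], [2.0, 5.0])       # (inputs, targets)
objective = ToyObjective(weight_decay=0.1)
train_value = objective.train_loss(model, [theta], batch)
test_value = objective.test_loss(model, [theta], batch)
```

In this sketch the regularizer enters only through train_loss(), so test_loss() reports the unregularized loss on the same batch.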
Influence Modules
torch-influence provides three subclasses of BaseInfluenceModule out-of-the-box.
Each subclass differs only in how the abstract function BaseInfluenceModule.inverse_hvp() is implemented. We refer readers to the original influence function paper (Koh & Liang, 2017) for further details.
- class torch_influence.AutogradInfluenceModule(model, objective, train_loader, test_loader, device, damp, check_eigvals=False)[source]
Bases: BaseInfluenceModule
An influence module that computes inverse-Hessian vector products by directly forming and inverting the risk Hessian matrix using torch.autograd utilities.
- Parameters
model (torch.nn.Module) – the model of interest.
objective (BaseObjective) – an implementation of BaseObjective.
train_loader (torch.utils.data.DataLoader) – a training dataset loader.
test_loader (torch.utils.data.DataLoader) – a test dataset loader.
device (torch.device) – the device on which operations are performed.
damp (float) – the damping strength \(\lambda\). Influence functions assume that the risk Hessian \(\mathbf{H}\) is positive-definite, which often fails to hold for neural networks. Hence, a damped risk Hessian \(\mathbf{H} + \lambda\mathbf{I}\) is used instead, for some sufficiently large \(\lambda > 0\) and identity matrix \(\mathbf{I}\).
check_eigvals (bool) – if True, this initializer checks that the damped risk Hessian is positive-definite, and raises a ValueError if it is not. Otherwise, no check is performed.
Warning
This module scales poorly with the number of model parameters \(d\). In general, computing the Hessian matrix takes \(\mathcal{O}(nd^2)\) time and inverting it takes \(\mathcal{O}(d^3)\) time, where \(n\) is the size of the training dataset.
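The effect of damp, and the kind of check that check_eigvals describes, can be illustrated on a tiny symmetric matrix. The sketch below is plain Python with hypothetical helper names (min_eigenvalue_2x2, damped, check_positive_definite). It shows the idea of testing \(\mathbf{H} + \lambda\mathbf{I}\) for positive definiteness and makes no claim about how the module performs its own check.

```python
import math


def min_eigenvalue_2x2(h):
    """Smallest eigenvalue of a symmetric 2x2 matrix [[a, b], [b, d]]."""
    (a, b), (_, d) = h
    return (a + d) / 2 - math.sqrt(((a - d) / 2) ** 2 + b ** 2)


def damped(h, damp):
    """Return H + damp * I."""
    return [[h[i][j] + (damp if i == j else 0.0) for j in range(2)] for i in range(2)]


def check_positive_definite(h, damp):
    """Raise ValueError unless H + damp * I is positive definite."""
    smallest = min_eigenvalue_2x2(damped(h, damp))
    if smallest <= 0:
        raise ValueError(
            f"damped Hessian is not positive definite (min eigenvalue {smallest:.3f})"
        )
    return smallest


H = [[1.0, 2.0], [2.0, 1.0]]  # indefinite: its eigenvalues are 3 and -1

# damp = 1.5 shifts the spectrum to (4.5, 0.5), so the check passes.
smallest_ok = check_positive_definite(H, damp=1.5)

# damp = 0.5 leaves a negative eigenvalue, so the check raises.
try:
    check_positive_definite(H, damp=0.5)
    too_small_raised = False
except ValueError:
    too_small_raised = True
```

Damping shifts every eigenvalue by \(\lambda\), so \(\lambda\) must exceed the magnitude of the most negative eigenvalue of \(\mathbf{H}\) for the damped matrix to be positive definite.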
- class torch_influence.CGInfluenceModule(model, objective, train_loader, test_loader, device, damp, gnh=False, **kwargs)[source]
Bases: BaseInfluenceModule
An influence module that computes inverse-Hessian vector products using the method of (truncated) Conjugate Gradients (CG).
This module relies on scipy.sparse.linalg.cg() to perform CG.
- Parameters
model (torch.nn.Module) – the model of interest.
objective (BaseObjective) – an implementation of BaseObjective.
train_loader (torch.utils.data.DataLoader) – a training dataset loader.
test_loader (torch.utils.data.DataLoader) – a test dataset loader.
device (torch.device) – the device on which operations are performed.
damp (float) – the damping strength \(\lambda\). Influence functions assume that the risk Hessian \(\mathbf{H}\) is positive-definite, which often fails to hold for neural networks. Hence, a damped risk Hessian \(\mathbf{H} + \lambda\mathbf{I}\) is used instead, for some sufficiently large \(\lambda > 0\) and identity matrix \(\mathbf{I}\).
gnh (bool) – if True, the risk Hessian \(\mathbf{H}\) is approximated with the Gauss-Newton Hessian, which is positive semi-definite. Otherwise, the risk Hessian is used.
**kwargs – keyword arguments which are passed into the “Other Parameters” of scipy.sparse.linalg.cg().
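The module itself delegates to scipy.sparse.linalg.cg(). The hand-written loop below is only a plain-Python illustration of why CG suits this problem: it needs Hessian-vector products, never the Hessian's inverse, and capping the iteration count gives the truncated variant. The matrix, the vector, and the helper names (matvec, dot, conjugate_gradient, damped_hvp) are assumptions made for the example.

```python
def matvec(a, x):
    """Dense matrix-vector product on nested lists."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def conjugate_gradient(hvp, b, max_iter=50, tol=1e-20):
    """Solve A x = b for symmetric positive-definite A, given only x -> A x."""
    x = [0.0] * len(b)
    r = list(b)   # residual b - A x (x starts at zero)
    p = list(r)   # search direction
    rs = dot(r, r)
    for _ in range(max_iter):
        if rs <= tol:  # squared residual norm is small enough
            break
        ap = hvp(p)
        alpha = rs / dot(p, ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


H = [[2.0, 1.0], [1.0, 3.0]]  # stand-in for the risk Hessian
DAMP = 0.1                    # damping strength lambda
VEC = [1.0, 2.0]              # vector to multiply by the inverse


def damped_hvp(p):
    """(H + DAMP * I) p, computed without forming the damped matrix."""
    return [hp + DAMP * pi for hp, pi in zip(matvec(H, p), p)]


ihvp = conjugate_gradient(damped_hvp, VEC)
residual = [lhs - rhs for lhs, rhs in zip(damped_hvp(ihvp), VEC)]
```

For a model with many parameters, the product in damped_hvp would come from automatic differentiation rather than an explicit matrix, which is what keeps the approach tractable.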
- class torch_influence.LiSSAInfluenceModule(model, objective, train_loader, test_loader, device, damp, repeat, depth, scale, gnh=False, debug_callback=None)[source]
Bases: BaseInfluenceModule
An influence module that computes inverse-Hessian vector products using the Linear time Stochastic Second-Order Algorithm (LiSSA).
At a high level, LiSSA estimates an inverse-Hessian vector product by using truncated Neumann iterations:
\[\mathbf{H}^{-1}\mathbf{v} \approx \frac{1}{R}\sum\limits_{r = 1}^R \left(\sigma^{-1}\sum_{t = 0}^{T}(\mathbf{I} - \sigma^{-1}\mathbf{H}_{r, t})^t\mathbf{v}\right)\]
Here, \(\mathbf{H}\) is the risk Hessian matrix and \(\mathbf{H}_{r, t}\) are loss Hessian matrices over batches of training data drawn randomly with replacement (the batch size is taken from train_loader). In addition, \(\sigma > 0\) is a scaling factor chosen sufficiently large such that \(\sigma^{-1} \mathbf{H} \preceq \mathbf{I}\).
In practice, we can compute each inner sum recursively. Starting with \(\mathbf{h}_{r, 0} = \mathbf{v}\), we can iteratively update for \(T\) steps:
\[\mathbf{h}_{r, t} = \mathbf{v} + \mathbf{h}_{r, t - 1} - \sigma^{-1}\mathbf{H}_{r, t}\mathbf{h}_{r, t - 1}\]
where \(\mathbf{h}_{r, T}\) will be equal to the \(r\)-th inner sum.
- Parameters
model (torch.nn.Module) – the model of interest.
objective (BaseObjective) – an implementation of BaseObjective.
train_loader (torch.utils.data.DataLoader) – a training dataset loader.
test_loader (torch.utils.data.DataLoader) – a test dataset loader.
device (torch.device) – the device on which operations are performed.
damp (float) – the damping strength \(\lambda\). Influence functions assume that the risk Hessian \(\mathbf{H}\) is positive-definite, which often fails to hold for neural networks. Hence, a damped risk Hessian \(\mathbf{H} + \lambda\mathbf{I}\) is used instead, for some sufficiently large \(\lambda > 0\) and identity matrix \(\mathbf{I}\).
repeat (int) – the number of trials \(R\).
depth (int) – the recurrence depth \(T\).
scale (float) – the scaling factor \(\sigma\).
gnh (bool) – if True, the risk Hessian \(\mathbf{H}\) is approximated with the Gauss-Newton Hessian, which is positive semi-definite. Otherwise, the risk Hessian is used.
debug_callback (Optional[Callable[[int, int, torch.Tensor], None]]) – a callback function which is passed in \((r, t, \mathbf{h}_{r, t})\) at each recurrence step.
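The roles of repeat, depth, and scale can be seen in a small plain-Python sketch of the recursion. It makes two simplifying assumptions that the real module does not: the same full matrix is used at every step in place of a randomly sampled batch Hessian (so every trial is identical and repeat has no effect here), and damping is applied by adding damp times the current iterate to each Hessian-vector product. The helper names matvec and lissa_inverse_hvp are hypothetical.

```python
def matvec(a, x):
    """Dense matrix-vector product on nested lists."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def lissa_inverse_hvp(hessian, vec, damp, repeat, depth, scale):
    """Average `repeat` runs of h_t = v + h_{t-1} - (H + damp * I) h_{t-1} / scale."""
    total = [0.0] * len(vec)
    for _ in range(repeat):
        h = list(vec)  # h_0 = v
        for _ in range(depth):
            hvp = matvec(hessian, h)
            hvp = [hv + damp * hi for hv, hi in zip(hvp, h)]
            h = [v + hi - hv / scale for v, hi, hv in zip(vec, h, hvp)]
        # Apply the final 1 / scale factor to this trial's inner sum.
        total = [acc + hi / scale for acc, hi in zip(total, h)]
    return [acc / repeat for acc in total]


H = [[2.0, 1.0], [1.0, 3.0]]  # positive definite, largest eigenvalue below 4
VEC = [1.0, 2.0]

# scale = 5 exceeds the largest eigenvalue, so the recursion converges.
approx = lissa_inverse_hvp(H, VEC, damp=0.0, repeat=1, depth=100, scale=5.0)

# Exact answer for comparison: H^{-1} v = [0.2, 0.6].
exact = [0.2, 0.6]
```

A scale that is too small for the (damped) Hessian makes the iterates grow without bound, while a scale that is far too large slows convergence and calls for a larger depth.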