Power of Hooks in PyTorch
What are hooks?
PyTorch allows you to attach custom function calls, called hooks, to its module and tensor objects. Hooks can be added both to the forward method of the object and to the backward method. A hook added to the forward method will be called with the following arguments:
- The instance of the module itself
- The input to the module
- The output of the forward method
import torch.nn.functional as F

def dropout_hook(module, input, output):
    # Apply 2-D dropout to the layer's output; returning a tensor replaces the original output
    output = F.dropout2d(output, 0.2, True, False)
    return output
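Backward hooks work the same way, except they receive gradients instead of activations. As a minimal sketch (the layer and helper here are my own, and register_full_backward_hook requires a reasonably recent PyTorch version), a backward hook that prints gradient norms could look like this:
import torch
import torch.nn as nn

def grad_norm_hook(module, grad_input, grad_output):
    # grad_output holds the gradients of the loss with respect to the module's outputs
    print(module.__class__.__name__, grad_output[0].norm().item())

layer = nn.Conv2d(3, 8, kernel_size=3)
handle = layer.register_full_backward_hook(grad_norm_hook)
layer(torch.randn(1, 3, 8, 8)).sum().backward()  # the backward pass triggers the hook
handle.remove()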
Why hooks?
Now that we know what hooks are, it’s important to understand their use cases. Most commonly, they are used for debugging, calculating model size, or counting the number of ops. Let’s say you imported a backbone from the torchvision package such as vgg16. If you wanted to calculate the number of ops for each layer, you might be tempted to rewrite the entire vgg16 backbone with op-counting commands added to the forward method. A better way is to add hooks to the modules without rewriting any of the vgg16 code, as sketched below.
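For example, a rough sketch of such an ops counter (the helper names here are my own, and the MAC estimate ignores groups and bias) could look like this:
import torch
import torch.nn as nn
import torchvision

ops_per_layer = {}

def count_conv_ops(module, input, output):
    # Rough MAC count for one Conv2d call: output elements * kernel area * input channels
    kh, kw = module.kernel_size
    ops_per_layer[module] = output.numel() * kh * kw * module.in_channels

model = torchvision.models.vgg16()
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        m.register_forward_hook(count_conv_ops)

model(torch.randn(1, 3, 224, 224))
print("Total conv MACs:", sum(ops_per_layer.values()))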
Example: Adding Dropout to a CNN
Let’s demonstrate the power of hooks with an example of adding dropout after every conv2d layer of a CNN.
Adding the Hook
Let’s write the hook that will apply the dropout. The hook takes in three arguments: the module itself, the input to the module, and the output generated by the module’s forward method. Our hook will simply apply the dropout function to the output and overwrite it. The dropout2d arguments are the tensor to modify, the probability of dropping, a training flag, and an inplace flag. The training flag can be set to True unconditionally, only when the model is training, or based on any custom condition of your choosing. Finally, inplace would overwrite the contents of output without creating a new tensor; this can raise an exception during training because autograd needs the original outputs in memory for gradient propagation, so we keep it as False.
def dropout_hook(module, input, output):
    # Apply dropout to the output tensor and return it so it replaces the layer's output
    output = F.dropout2d(output, 0.2, True, False)
    return output
Now that the hook is ready, we need to register it on the model. The way to do that is to call the register_forward_hook method of a module, passing it the dropout_hook function. Registering it on every module would add the dropout hook to every layer of the model, but we only need it on the output of the convolutional layers. For that we create another function called register_hook:
import torch.nn as nn

def register_hook(module):
    # Register the dropout hook only on convolutional layers
    if isinstance(module, nn.Conv2d):
        module.register_forward_hook(dropout_hook)

model.apply(register_hook)
The apply method calls the passed function recursively on every nn.Module within the model, so every conv2d layer is guaranteed to get the dropout hook. For reconfigurability, let us create a dropout hook class that also stores the probability of dropping activation values. We will add a remove method as well that removes the hooks added to the model.
class DropoutHook:
    def __init__(self, prob):
        self.prob = prob
        self.handles = []  # keep the handles so the hooks can be removed later

    def register_hook(self, module):
        # Register the dropout hook only on convolutional layers
        if isinstance(module, nn.Conv2d):
            self.handles += [module.register_forward_hook(self.dropout_hook)]

    def dropout_hook(self, module, input, output):
        # Apply dropout with the stored probability and replace the layer's output
        output = F.dropout2d(output, self.prob, True, False)
        return output

    def remove(self):
        # Detach every hook registered through this instance
        for handle_ in self.handles:
            handle_.remove()
Testing it Out
To test it out, let’s import the vgg16 backbone from the torchvision package. We will pass a random input through the model and store the output for reference. We will then register our dropout hook, run the model again, remove the hook, run it once more, and compare both outputs with the reference.
import torch
import torchvision
# Load the pre-trained model
model = torchvision.models.vgg16(pretrained=True)
# Set to eval mode
model.eval()
# Create a random input vector and store the reference output
x = torch.randn((1,3,224,224))
refOut = model(x)
# Instantiate the hook class and register the hook to the model
dropout_ = DropoutHook(prob=0.2)
model.apply(dropout_.register_hook)
# Evaluate with the hooks enabled
outWithDropout = model(x)
# Remove the hook and re-evaluate
dropout_.remove()
outWithOutDropout = model(x)
# Compare the outputs with the reference
errDropoutModel = (outWithDropout - refOut).abs().mean()
errWithoutDropoutModel = (outWithOutDropout - refOut).abs().mean()
print("Dropout Model Error: {}, Non-dropout Model Error: {}".format(errDropoutModel, errWithoutDropoutModel))
The output clearly shows that the dropout hook is changing the outputs of the conv2d layers: the error is non-zero while the hooks are registered and drops to zero once they are removed.
Downsides
- Hooks are not serializable, which means you cannot call torch.save on a model that has hooks attached (see the sketch after this list for one way around this)
- Hook references are not maintained in the model. Instead, you have to store the handle to each hook (see the DropoutHook class for an example)
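If you do need to serialize a model that had hooks attached, a minimal workaround (a sketch building on the DropoutHook instance from above; the file name is only illustrative) is to remove the hooks before saving and re-register them afterwards:
# Detach the hooks, save the model, then re-attach the hooks
dropout_.remove()
torch.save(model, "vgg16_no_hooks.pth")  # hypothetical file name
model.apply(dropout_.register_hook)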