Registers a backward hook on the module.
The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature:
hook(module, grad_input, grad_output) -> Tensor or None
The grad_input and grad_output arguments may be tuples if the module has multiple inputs or outputs. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input, which will be used in place of grad_input in subsequent computations.
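The sketch below shows a hook with this signature that records its arguments and returns a replacement for grad_input. It assumes nn.ReLU as the module, because its forward pass is a single autograd operation, so grad_input and grad_output line up with the module's input and output. The names scale_grad_hook and seen are illustrative only.

```python
import torch
import torch.nn as nn

# Store copies of what the hook receives so they can be inspected after backward().
seen = {}

def scale_grad_hook(module, grad_input, grad_output):
    # Both arguments arrive as tuples of tensors; entries may be None.
    seen["grad_input"] = [g.clone() for g in grad_input if g is not None]
    seen["grad_output"] = [g.clone() for g in grad_output if g is not None]
    # Do not modify the arguments in place: build a new tuple instead.
    # The returned tuple replaces grad_input for the rest of the backward pass.
    return tuple(None if g is None else 2.0 * g for g in grad_input)

relu = nn.ReLU()
handle = relu.register_backward_hook(scale_grad_hook)

x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
# Use an intermediate tensor as the module input; x.grad then shows the
# gradient after the hook's replacement has propagated back to x.
h = x * 1.0
y = relu(h)
y.sum().backward()

print(x.grad)  # tensor([0., 2., 2.])
handle.remove()
```

Without the hook, x.grad would be the ReLU derivative [0, 1, 1]; the doubled values come from the tuple the hook returned.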
Returns: a handle (torch.utils.hooks.RemovableHandle) that can be used to remove the added hook by calling handle.remove().
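A minimal sketch of removing the hook through the returned handle, again assuming nn.ReLU as a stand-in module; counting_hook and calls are hypothetical names. After handle.remove(), later backward passes no longer invoke the hook.

```python
import torch
import torch.nn as nn

calls = []

def counting_hook(module, grad_input, grad_output):
    # Record the call; returning None leaves grad_input unchanged.
    calls.append(type(module).__name__)
    return None

relu = nn.ReLU()
handle = relu.register_backward_hook(counting_hook)

x = torch.randn(4, requires_grad=True)
relu(x * 1.0).sum().backward()  # the hook runs during this backward pass
handle.remove()                 # detach the hook from the module
relu(x * 1.0).sum().backward()  # the hook is no longer called

print(calls)  # ['ReLU']
```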
Warning: The current implementation will not have the behavior described above for complex Module instances that perform many operations. In some failure cases, grad_input and grad_output will only contain the gradients for a subset of the inputs and outputs. For such a Module, you should use torch.Tensor.register_hook directly on a specific input or output to get the required gradients.
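The workaround from the warning can be sketched as follows, assuming a small nn.Sequential as an example of a module whose forward runs several autograd operations. Tensor hooks receive a single gradient tensor, so each one captures exactly the gradient for the tensor it is attached to. The helper save_grad and the grads dictionary are illustrative. Newer PyTorch releases also offer Module.register_full_backward_hook, which is not shown here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A module whose forward pass involves several autograd operations.
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 2))

grads = {}

def save_grad(name):
    # Build a tensor hook that stores a copy of the incoming gradient.
    # Returning None keeps the gradient unchanged.
    def hook(grad):
        grads[name] = grad.detach().clone()
    return hook

x = torch.randn(5, 3, requires_grad=True)
x.register_hook(save_grad("input"))     # gradient w.r.t. the module input
out = model(x)
out.register_hook(save_grad("output"))  # gradient w.r.t. the module output
out.sum().backward()

print(grads["input"].shape, grads["output"].shape)
# torch.Size([5, 3]) torch.Size([5, 2])
```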