analogvnn.backward.BackwardFunction
Module Contents
Classes
BackwardFunction – The backward module that uses a function to compute the backward gradient.
- class analogvnn.backward.BackwardFunction.BackwardFunction(backward_function: analogvnn.utils.common_types.TENSOR_CALLABLE, layer: torch.nn.Module = None)
Bases: analogvnn.backward.BackwardModule.BackwardModule, abc.ABC
The backward module that uses a function to compute the backward gradient.
- Variables:
_backward_function (TENSOR_CALLABLE) – The function used to compute the backward gradient.
- property backward_function: analogvnn.utils.common_types.TENSOR_CALLABLE
The function used to compute the backward gradient.
- Returns:
The function used to compute the backward gradient.
- Return type:
TENSOR_CALLABLE
- set_backward_function(backward_function: analogvnn.utils.common_types.TENSOR_CALLABLE) → BackwardFunction
Sets the function used to compute the backward gradient.
- Parameters:
backward_function (TENSOR_CALLABLE) – The function to use for computing the backward gradient.
- Returns:
self.
- Return type:
BackwardFunction
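Because set_backward_function returns self, the call can be chained directly onto construction. The sketch below illustrates that pattern with a hypothetical stand-in class (FluentSetterSketch) and a hypothetical gradient function (halve); it does not import analogvnn and is not the library's implementation, only a minimal model of the documented behavior.

```python
from typing import Callable, Optional


class FluentSetterSketch:
    """Hypothetical stand-in that mimics the documented setter; not the analogvnn class."""

    def __init__(self, backward_function: Optional[Callable] = None):
        self._backward_function = backward_function

    @property
    def backward_function(self) -> Optional[Callable]:
        # Read-only view of the stored function, as the documented property describes.
        return self._backward_function

    def set_backward_function(self, backward_function: Callable) -> "FluentSetterSketch":
        self._backward_function = backward_function
        return self  # returning self is what makes chaining possible


def halve(grad):
    # Hypothetical gradient rule, used only for illustration.
    return grad * 0.5


# Construct and configure in one expression.
module = FluentSetterSketch().set_backward_function(halve)
```

With the real class, the same chaining should be possible on a BackwardFunction instance, assuming it behaves as documented here.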
- backward(*grad_output: torch.Tensor, **grad_output_kwarg: torch.Tensor) → analogvnn.utils.common_types.TENSORS
Computes the gradients of the layer's inputs from the gradients of its outputs, using the backward function.
- Parameters:
*grad_output (Tensor) – The gradients of the output of the layer, passed positionally.
**grad_output_kwarg (Tensor) – The gradients of the output of the layer, passed by keyword.
- Returns:
The gradients of the input of the layer.
- Return type:
TENSORS
- Raises:
NotImplementedError – If the backward function is not set.
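The documented behavior of backward (forward the output gradients to the stored function, or raise NotImplementedError when none is set) can be sketched as follows. BackwardFunctionSketch and combine are hypothetical names, plain floats stand in for tensors, and the delegation logic is an assumption based on the description above rather than the library's source.

```python
class BackwardFunctionSketch:
    """Hypothetical stand-in for the documented behavior; not the analogvnn class."""

    def __init__(self, backward_function=None):
        self._backward_function = backward_function

    def backward(self, *grad_output, **grad_output_kwarg):
        # Assumed behavior: fail loudly when no function has been set.
        if self._backward_function is None:
            raise NotImplementedError("backward_function is not set")
        # Assumed behavior: pass all output gradients straight through.
        return self._backward_function(*grad_output, **grad_output_kwarg)


def combine(grad_a, grad_b=0.0):
    # Hypothetical gradient rule: sum a positional and a keyword gradient.
    return grad_a + grad_b


# Positional and keyword gradients both reach the stored function.
grad_input = BackwardFunctionSketch(combine).backward(2.0, grad_b=3.0)

# Without a backward function, the call raises NotImplementedError.
try:
    BackwardFunctionSketch().backward(2.0)
    raised = False
except NotImplementedError:
    raised = True
```

In real use, the arguments would be torch.Tensor gradients rather than floats.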