analogvnn.graph.ModelGraphState
Module Contents
Classes
ModelGraphState | The state of a model graph.
- class analogvnn.graph.ModelGraphState.ModelGraphState(use_autograd_graph: bool = False, allow_loops=False)
The state of a model graph.
- Variables:
allow_loops (bool) – if True, the graph is allowed to contain loops.
forward_input_output_graph (Optional[Dict[GRAPH_NODE_TYPE, InputOutput]]) – the input and output of the forward pass, keyed by graph node.
use_autograd_graph (bool) – if True, the autograd graph is used to calculate the gradients.
_loss (Tensor) – the loss.
INPUT (GraphEnum) – GraphEnum.INPUT
OUTPUT (GraphEnum) – GraphEnum.OUTPUT
STOP (GraphEnum) – GraphEnum.STOP
- Properties:
inputs (Optional[ArgsKwargs]) – the inputs of the forward pass.
outputs (Optional[ArgsKwargs]) – the outputs of the forward pass.
loss (Tensor) – the loss.
- property inputs: Optional[analogvnn.graph.ArgsKwargs.ArgsKwargs]
Get the inputs.
- Returns:
the inputs.
- Return type:
Optional[ArgsKwargs]
- property outputs: Optional[analogvnn.graph.ArgsKwargs.ArgsKwargs]
Get the outputs.
- Returns:
the outputs.
- Return type:
Optional[ArgsKwargs]
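Both properties read from forward_input_output_graph. The sketch below is a hypothetical stand-in, not the analogvnn implementation: it assumes the graph is a dict keyed by the INPUT and OUTPUT nodes, that each value exposes `inputs`/`outputs` attributes, and that a missing entry (or a graph that has not been populated yet) yields None. `GraphEnum`, `InputOutput` and `StateSketch` are simplified local stand-ins, and plain tuples and floats replace ArgsKwargs.

```python
from enum import Enum
from typing import Any, Dict, Optional


class GraphEnum(Enum):
    # Hypothetical stand-in for analogvnn's GraphEnum node markers.
    INPUT = "INPUT"
    OUTPUT = "OUTPUT"
    STOP = "STOP"


class InputOutput:
    # Hypothetical stand-in: what a graph node received and produced.
    def __init__(self, inputs: Any = None, outputs: Any = None):
        self.inputs = inputs
        self.outputs = outputs


class StateSketch:
    """Assumed lookup logic for the inputs/outputs properties."""

    INPUT = GraphEnum.INPUT
    OUTPUT = GraphEnum.OUTPUT

    def __init__(self):
        # Filled in by a forward pass; None before any pass has run.
        self.forward_input_output_graph: Optional[Dict[Any, InputOutput]] = None

    @property
    def inputs(self) -> Optional[Any]:
        graph = self.forward_input_output_graph
        if graph is None or self.INPUT not in graph:
            return None
        return graph[self.INPUT].inputs

    @property
    def outputs(self) -> Optional[Any]:
        graph = self.forward_input_output_graph
        if graph is None or self.OUTPUT not in graph:
            return None
        return graph[self.OUTPUT].outputs


state = StateSketch()
print(state.inputs, state.outputs)  # None None: nothing recorded yet
state.forward_input_output_graph = {
    GraphEnum.INPUT: InputOutput(inputs=(1.0, 2.0)),
    GraphEnum.OUTPUT: InputOutput(outputs=3.0),
}
print(state.inputs, state.outputs)  # (1.0, 2.0) 3.0
```

Returning None instead of raising keeps both properties safe to query before a forward pass has been run.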
- forward_input_output_graph: Optional[Dict[analogvnn.graph.GraphEnum.GRAPH_NODE_TYPE, analogvnn.graph.ArgsKwargs.InputOutput]]
- _loss: Optional[torch.Tensor]
- ready_for_forward(exception: bool = False) → bool
Check if the state is ready for the forward pass.
- Parameters:
exception (bool) – If True, an exception is raised if the state is not ready for the forward pass.
- Returns:
True if the state is ready for the forward pass, otherwise False.
- Return type:
bool
- Raises:
RuntimeError – If the state is not ready for the forward pass and exception is True.
- ready_for_backward(exception: bool = False) → bool
Check if the state is ready for the backward pass.
- Parameters:
exception (bool) – If True, an exception is raised if the state is not ready for the backward pass.
- Returns:
True if the state is ready for the backward pass, otherwise False.
- Return type:
bool
- Raises:
RuntimeError – If the state is not ready for the backward pass and exception is True.
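ready_for_forward and ready_for_backward share one contract: return a bool, or raise RuntimeError when `exception=True` and the state is not ready. The sketch below illustrates only that contract. The readiness rules themselves are assumptions, not taken from analogvnn: forward is treated as always ready, and backward is assumed to need both a recorded output and a loss. `ReadinessSketch` is a hypothetical stand-in class.

```python
from typing import Any, List, Optional


class ReadinessSketch:
    """Hypothetical stand-in for the check-or-raise readiness contract."""

    def __init__(self):
        self.outputs: Optional[Any] = None  # assumed to be set by a forward pass
        self._loss: Optional[Any] = None    # assumed to be set via set_loss

    def ready_for_forward(self, exception: bool = False) -> bool:
        # Assumption: a fresh state can always run a forward pass.
        return True

    def ready_for_backward(self, exception: bool = False) -> bool:
        # Assumption: backward needs both a recorded output and a loss.
        missing: List[str] = []
        if self.outputs is None:
            missing.append("outputs")
        if self._loss is None:
            missing.append("loss")
        if missing and exception:
            raise RuntimeError(
                "not ready for backward pass: missing " + ", ".join(missing)
            )
        return not missing


state = ReadinessSketch()
print(state.ready_for_forward())   # True
print(state.ready_for_backward())  # False
try:
    state.ready_for_backward(exception=True)
except RuntimeError as err:
    print(err)  # not ready for backward pass: missing outputs, loss
state.outputs, state._loss = 3.0, 0.5
print(state.ready_for_backward(exception=True))  # True
```

The `exception` flag lets one method serve both as a cheap query (`if state.ready_for_backward(): ...`) and as a guard that fails loudly with a descriptive message.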
- set_loss(loss: Union[torch.Tensor, None]) → ModelGraphState
Set the loss.
- Parameters:
loss (Optional[Tensor]) – the loss.
- Returns:
self.
- Return type:
ModelGraphState
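Because set_loss returns self, assigning the loss can be chained with other calls on the same state. The sketch below is a hypothetical stand-in (`LossSketch`), not the analogvnn class: plain floats replace torch.Tensor, and the `loss` accessor simply mirrors the `loss` property named in the class's Properties list.

```python
from typing import Any, Optional


class LossSketch:
    """Hypothetical stand-in: set_loss stores the loss and returns self."""

    def __init__(self):
        self._loss: Optional[Any] = None

    def set_loss(self, loss: Optional[Any]) -> "LossSketch":
        self._loss = loss  # passing None replaces any stored loss with None
        return self        # returning self is what enables call chaining

    @property
    def loss(self) -> Optional[Any]:
        # Read-only accessor for the stored loss.
        return self._loss


state = LossSketch()
assert state.set_loss(0.25) is state
print(state.set_loss(0.25).loss)  # 0.25
print(state.set_loss(None).loss)  # None
```

With the documented class, the same pattern would read `state = ModelGraphState().set_loss(loss)`, where `loss` is a torch.Tensor.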