torchadf.nn.functional.flatten
- torchadf.nn.functional.flatten(in_mean, in_var, start_dim, end_dim, mode='diag')
Flattens the inputs along a contiguous range of dimensions.
Assumed Density Filtering (ADF) version of torch.flatten. The dimensions to be flattened refer to the first input (mean) tensor. The corresponding dimensions of the second input ((co-)variance) tensor are inferred from the covariance propagation mode. (In the full covariance mode this can be ambiguous when the number of leading “batch” dimensions is unknown, so we assume that all dimensions before the specified start_dim are batch dimensions.)
- Parameters:
- in_mean : torch.Tensor
Input mean tensor.
- in_var : torch.Tensor
Input (co-)variance tensor.
- start_dim : int, optional
First dimension to flatten (Default 1).
- end_dim : int, optional
Last dimension to flatten (Default -1).
- mode : {“diag”, “diagonal”, “lowrank”, “half”, “full”}, optional
Covariance propagation mode (Default “diag”).
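To make the batch-dimension assumption concrete, the sketch below computes which dimension ranges of the variance tensor would have to be flattened for a given start_dim and end_dim. It is a hypothetical helper (variance_flatten_ranges is not part of torchadf) and rests on two assumptions not stated in this reference: in "diag" mode the variance has the same shape as the mean, and in "full" mode the covariance has the shape mean.shape + mean.shape[start_dim:], i.e. the non-batch dimensions appear twice. Treating "diagonal" as an alias of "diag" is also an assumption; "lowrank" and "half" are not covered.

```python
from typing import List, Tuple


def variance_flatten_ranges(
    mean_ndim: int, start_dim: int = 1, end_dim: int = -1, mode: str = "diag"
) -> List[Tuple[int, int]]:
    """Hypothetical helper: inclusive (start, end) dim ranges to flatten on the variance.

    The ranges are ordered so that applying them one after another
    (e.g. with torch.flatten) keeps the later indices valid.
    """
    # Normalize negative indices against the mean tensor, as torch.flatten does.
    start = start_dim + mean_ndim if start_dim < 0 else start_dim
    end = end_dim + mean_ndim if end_dim < 0 else end_dim
    if not 0 <= start <= end < mean_ndim:
        raise ValueError("invalid start_dim/end_dim for the mean tensor")

    if mode in ("diag", "diagonal"):  # assumption: "diagonal" is an alias of "diag"
        # Variances share the mean's layout, so the same range is flattened.
        return [(start, end)]
    if mode == "full":
        # Assumed layout: batch dims (everything before start_dim), then the
        # feature dims twice; n_feat is the length of one feature-dim copy.
        n_feat = mean_ndim - start
        # Flatten the trailing copy first so the leading copy's indices are unchanged.
        return [(start + n_feat, end + n_feat), (start, end)]
    raise NotImplementedError(f"mode {mode!r} is not covered by this sketch")
```

For a 4-D mean such as (8, 3, 4, 5), calling variance_flatten_ranges(4, 2, 3, mode="full") returns [(4, 5), (2, 3)]: under the stated assumption both the first and the second dimension are treated as batch dimensions, which is exactly the ambiguity noted above when the true batch size is only the first dimension.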
- Returns:
- out_mean : torch.Tensor
The reshaped mean tensor.
- out_var : torch.Tensor
The reshaped (co-)variance tensor.
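The returned pair can be illustrated with plain PyTorch. The function below, adf_flatten_sketch, is a hypothetical re-implementation written for this illustration only, not torchadf's code; it covers just the "diag" and "full" modes and assumes the same tensor layouts as above (variance shaped like the mean in "diag" mode, covariance shaped mean.shape + mean.shape[start_dim:] in "full" mode). Only torch.flatten, torch.diag_embed and torch.diagonal from PyTorch are used.

```python
import torch


def adf_flatten_sketch(in_mean, in_var, start_dim=1, end_dim=-1, mode="diag"):
    """Hypothetical ADF flatten for "diag" and "full" modes (illustration only)."""
    ndim = in_mean.dim()
    # Resolve negative indices relative to the mean tensor.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim

    out_mean = torch.flatten(in_mean, start, end)
    if mode in ("diag", "diagonal"):
        # Diagonal mode: per-element variances are reshaped like the mean.
        out_var = torch.flatten(in_var, start, end)
    elif mode == "full":
        # All dims before start_dim are assumed to be batch dims.
        n_feat = ndim - start
        # Flatten the trailing copy of the feature dims first, then the leading copy.
        out_var = torch.flatten(in_var, start + n_feat, end + n_feat)
        out_var = torch.flatten(out_var, start, end)
    else:
        raise NotImplementedError(f"mode {mode!r} is not covered by this sketch")
    return out_mean, out_var


# Usage: a full covariance built from diagonal variances should flatten
# consistently with the diagonal variances themselves.
mean = torch.randn(2, 3, 4)
var = torch.rand(2, 3, 4)
cov = torch.diag_embed(var.flatten(1)).reshape(2, 3, 4, 3, 4)

out_mean, out_var = adf_flatten_sketch(mean, var)                   # shapes (2, 12), (2, 12)
_, out_cov = adf_flatten_sketch(mean, cov, mode="full")             # shape (2, 12, 12)
print(torch.equal(torch.diagonal(out_cov, dim1=-2, dim2=-1), out_var))
```

Flattening both copies of the feature dimensions in the same row-major order is what keeps the diagonal of the flattened covariance aligned with the flattened variance, so the check in the usage example should hold.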