torchadf.nn.modules.flatten.Unflatten
- class torchadf.nn.modules.flatten.Unflatten(dim, unflattened_size, mode='diag')
Unflattens a dimension of the input over multiple dimensions.
Assumed Density Filtering (ADF) version of torch.nn.Unflatten. The dimension to unflatten and the target shape refer to the first input (the mean tensor). The corresponding dimensions of the second input (the covariance tensor) are inferred from the covariance propagation mode. (In full covariance mode this can be ambiguous when the number of leading “batch” dimensions is unknown, so all dimensions before the specified unflatten dimension are assumed to be batch dimensions.)
- Parameters:
- dim: int
Dimension to unflatten.
- unflattened_size: tuple of int
Shape into which the selected dimension should be unflattened.
- mode: {“diag”, “diagonal”, “lowrank”, “half”, “full”}, optional
Covariance propagation mode (default: “diag”).
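The shape bookkeeping for the two inputs can be sketched in plain Python. The helper `adf_unflatten_shapes` below is hypothetical and not part of torchadf. It assumes that in diag mode the covariance (variance) tensor has the same shape as the mean, and that in full mode it has shape batch + features + features, where batch is every dimension before `dim` (matching the batch assumption stated above). The actual torchadf tensor layouts, and the lowrank and half modes, are not covered.

```python
from math import prod
from typing import Sequence, Tuple


def adf_unflatten_shapes(
    mean_shape: Sequence[int],
    dim: int,
    unflattened_size: Sequence[int],
    mode: str = "diag",
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper: output shapes of (mean, covariance) after unflattening.

    Assumed layouts (illustrative, not taken from torchadf):
      diag/diagonal: covariance has the same shape as the mean.
      full: covariance has shape batch + features + features, where batch
            is every dimension before `dim` and features is everything else.
    """
    mean_shape = tuple(mean_shape)
    ndim = len(mean_shape)
    if dim < 0:  # allow negative indexing, as PyTorch does
        dim += ndim
    if not 0 <= dim < ndim:
        raise IndexError(f"dim out of range for shape {mean_shape}")
    # This sketch requires explicit sizes (no -1 placeholder).
    if prod(unflattened_size) != mean_shape[dim]:
        raise ValueError(
            f"unflattened_size {tuple(unflattened_size)} does not match "
            f"size {mean_shape[dim]} of dim {dim}"
        )

    # Mean: replace the selected dimension by the target sizes.
    new_mean = mean_shape[:dim] + tuple(unflattened_size) + mean_shape[dim + 1:]

    if mode in ("diag", "diagonal"):
        return new_mean, new_mean
    if mode == "full":
        batch = new_mean[:dim]      # dims before `dim` treated as batch dims
        features = new_mean[dim:]   # all remaining dims are feature dims
        return new_mean, batch + features + features
    raise NotImplementedError(f"mode {mode!r} is not covered by this sketch")


print(adf_unflatten_shapes((8, 12), 1, (3, 4), mode="full"))
# ((8, 3, 4), (8, 3, 4, 3, 4))
```

Under these assumptions, a full-mode covariance of shape (8, 12, 12) for a mean of shape (8, 12) becomes (8, 3, 4, 3, 4): the unflattening is applied to both copies of the feature dimensions, while the batch dimension is shared.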
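In diag mode, assuming the variance tensor shares the mean's shape, unflattening amounts to applying the same reshape to both inputs. The sketch below illustrates this with the standard `torch.nn.Unflatten` only; it does not call torchadf, and the diag-mode layout is an assumption.

```python
import torch
import torch.nn as nn

# Mean and (assumed diag-mode) variance for a batch of 4 inputs with 50 features.
mean = torch.randn(4, 50)
var = torch.rand(4, 50)  # elementwise variances: same shape as the mean

# Unflatten dimension 1 (50 features) into a (2, 5, 5) block.
unflatten = nn.Unflatten(dim=1, unflattened_size=(2, 5, 5))
mean_out = unflatten(mean)
var_out = unflatten(var)

print(mean_out.shape, var_out.shape)
# torch.Size([4, 2, 5, 5]) torch.Size([4, 2, 5, 5])
```

Because the reshape only reorganizes elements, each variance entry stays aligned with its mean entry; full-covariance inputs need the batch-dimension convention described above instead.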