torchadf.nn.functional.flatten

torchadf.nn.functional.flatten(in_mean, in_var, start_dim=1, end_dim=-1, mode='diag')

Flattens the inputs along a contiguous range of dimensions.

Assumed Density Filtering (ADF) version of torch.flatten. The dimensions to be flattened refer to the first input (mean) tensor. The corresponding dimensions of the second input (covariance) tensor are inferred from the covariance propagation mode. In the full covariance mode this inference can be ambiguous when the number of leading “batch” dimensions is unknown, so all dimensions before start_dim are assumed to be batch dimensions.
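The dimension bookkeeping described above can be illustrated with a small pure-Python sketch that only computes output shapes. flatten_shapes is a hypothetical helper, not part of torchadf. It assumes that a "diag" variance has the same shape as the mean and that a "full" covariance has shape (*batch, *feat, *feat), where batch is every mean dimension before start_dim; the "lowrank" and "half" layouts are not modeled.

```python
from math import prod
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def _resolve(dim: int, ndim: int) -> int:
    """Turn a possibly negative dim index into a non-negative one."""
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for {ndim} dims")
    return dim % ndim


def _flatten_shape(shape: Sequence[int], start: int, end: int) -> Shape:
    """Merge dims start..end (inclusive, non-negative) of `shape` into one."""
    shape = tuple(shape)
    return shape[:start] + (prod(shape[start:end + 1]),) + shape[end + 1:]


def flatten_shapes(mean_shape, var_shape, start_dim=1, end_dim=-1, mode="diag"):
    """Hypothetical helper: output shapes of an ADF flatten (shapes only).

    ASSUMPTION: "diag" variances share the mean's shape; "full" covariances
    have shape (*batch, *feat, *feat) with batch = mean_shape[:start_dim].
    """
    ndim = len(mean_shape)
    # Dimensions always refer to the mean tensor.
    start, end = _resolve(start_dim, ndim), _resolve(end_dim, ndim)
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    out_mean = _flatten_shape(mean_shape, start, end)

    if mode in ("diag", "diagonal"):
        if tuple(var_shape) != tuple(mean_shape):
            raise ValueError("diagonal variance must match the mean shape")
        return out_mean, _flatten_shape(var_shape, start, end)

    if mode == "full":
        n_feat = ndim - start  # every dim from start_dim on is a feature dim
        expected = tuple(mean_shape) + tuple(mean_shape[start:])
        if tuple(var_shape) != expected:
            raise ValueError(f"expected full covariance shape {expected}")
        # Flatten the range in the first copy of the feature dims ...
        out_var = _flatten_shape(var_shape, start, end)
        # ... then in the second copy, which now begins after the shrunk first copy.
        shift = n_feat - (end - start)
        out_var = _flatten_shape(out_var, start + shift, end + shift)
        return out_mean, out_var

    raise NotImplementedError(f"mode {mode!r} is not covered by this sketch")
```

For example, under these assumptions flatten_shapes((2, 3, 4, 5), (2, 3, 4, 5, 3, 4, 5), mode="full") returns ((2, 60), (2, 60, 60)): both copies of the feature dimensions are flattened, while the leading batch dimension is left alone.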

Parameters:
in_mean : torch.Tensor

Input mean tensor.

in_var : torch.Tensor

Input (co-)variance tensor.

start_dim : int, optional

First dimension to flatten (Default 1).

end_dim : int, optional

Last dimension to flatten (Default -1).

mode : {“diag”, “diagonal”, “lowrank”, “half”, “full”}, optional

Covariance propagation mode (Default “diag”).

Returns:
out_mean : torch.Tensor

The reshaped mean tensor.

out_var : torch.Tensor

The reshaped (co-)variance tensor.