torchadf.nn.functional.linear

torchadf.nn.functional.linear(in_mean, in_var, weight, bias=None, mode='diag')

Applies a dense (fully connected) linear transform.

Assumed Density Filtering (ADF) version of torch.nn.functional.linear.
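To illustrate what ADF propagation through a linear layer computes, the following is a minimal NumPy sketch of the diagonal case. It is not torchadf's implementation. It assumes that `weight` uses the same `(out_dim, in_dim)` layout as `torch.nn.functional.linear` (so the mean map is `y = x W^T + b`), that the weights are deterministic, and that the input entries are independent so only their variances are tracked. The function name `adf_linear_diag` is hypothetical.

```python
import numpy as np


def adf_linear_diag(in_mean, in_var, weight, bias=None):
    """Hypothetical sketch of diagonal-mode ADF propagation through y = x W^T + b.

    in_mean, in_var: arrays of shape (..., in_dim)
    weight: array of shape (out_dim, in_dim)  (assumed torch-style layout)
    bias: optional array of shape (out_dim,)
    """
    # The mean is pushed through the affine map exactly.
    out_mean = in_mean @ weight.T
    if bias is not None:
        out_mean = out_mean + bias
    # For independent inputs, Var[sum_j W_ij x_j] = sum_j W_ij^2 Var[x_j].
    # The bias is deterministic, so it does not change the variance.
    out_var = in_var @ (weight ** 2).T
    return out_mean, out_var


W = np.array([[1.0, 2.0], [0.5, -1.0], [0.0, 3.0]])  # out_dim=3, in_dim=2
b = np.array([0.1, 0.2, 0.3])
mu = np.array([1.0, -1.0])
var = np.array([0.5, 2.0])

out_mean, out_var = adf_linear_diag(mu, var, W, b)
# out_mean ≈ [-0.9, 1.7, -2.7], out_var = [8.5, 2.125, 18.0]
```

Because only `in_var` is multiplied by the squared weights, the diagonal mode costs about the same as the mean computation, at the price of ignoring correlations between outputs.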

Parameters:
in_mean : torch.Tensor

Input mean tensor. Expected shape is (*, in_dim).

in_var : torch.Tensor

Input (co-)variance tensor. Expected shape is (*, in_dim), (*, in_dim, rank), or (*, in_dim, in_dim) depending on the mode.

weight : torch.Tensor or torch.nn.parameter.Parameter

Weight matrix of the affine linear transform.

bias : torch.Tensor or torch.nn.parameter.Parameter, optional

Bias vector of the affine linear transform (Default None).

mode : {“diag”, “diagonal”, “lowrank”, “half”, “full”}, optional

Covariance propagation mode (Default “diag”).
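The non-diagonal modes can be understood through the standard rule that an affine map `x -> W x + b` turns a covariance `S` into `W S W^T`. The NumPy sketch below shows this under stated assumptions: for `"full"`, `in_var` is taken to hold `S` itself, and for `"lowrank"`, it is taken to hold a factor `L` with `S = L L^T`, so only `W L` needs to be computed. The function name `adf_linear_cov` is hypothetical, the `(out_dim, in_dim)` weight layout is assumed, and the `"half"` mode is left out because this page does not spell out its tensor layout.

```python
import numpy as np


def adf_linear_cov(in_mean, in_var, weight, bias=None, mode="full"):
    """Hypothetical sketch of covariance propagation for non-diagonal modes.

    mode="full":    in_var has shape (..., in_dim, in_dim) and holds S;
                    the result is W S W^T with shape (..., out_dim, out_dim).
    mode="lowrank": in_var has shape (..., in_dim, rank) and holds L, S = L L^T;
                    the result is the factor W L with shape (..., out_dim, rank).
    """
    # The mean update does not depend on the covariance mode.
    out_mean = in_mean @ weight.T
    if bias is not None:
        out_mean = out_mean + bias
    if mode == "full":
        # matmul broadcasts the (out_dim, in_dim) weight over leading batch axes.
        out_var = weight @ in_var @ weight.T
    elif mode == "lowrank":
        # (W L)(W L)^T = W S W^T, but the full matrix is never formed.
        out_var = weight @ in_var
    else:
        raise ValueError(f"mode not covered by this sketch: {mode!r}")
    return out_mean, out_var


rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))  # out_dim=3, in_dim=4
mu = rng.normal(size=(4,))
L = rng.normal(size=(4, 2))  # rank-2 covariance factor
S = L @ L.T                  # the corresponding full covariance

_, cov_full = adf_linear_cov(mu, S, W, mode="full")
_, factor = adf_linear_cov(mu, L, W, mode="lowrank")
# Both modes describe the same output covariance.
assert np.allclose(cov_full, factor @ factor.T)
```

The low-rank layout keeps the trailing `rank` dimension unchanged, which is why its cost grows with `rank` rather than with `in_dim`.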

Returns:
out_mean : torch.Tensor

The transformed mean tensor of shape (*, out_dim).

out_var : torch.Tensor

The transformed (co-)variance tensor of shape (*, out_dim), (*, out_dim, rank), or (*, out_dim, out_dim) depending on the mode.
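The shape contract above can be written down as a small pure-Python checker, which may help when wiring the function into a model. This is a sketch under assumptions: it maps `"diag"`/`"diagonal"` to the `(*, in_dim)` layout, `"lowrank"` to `(*, in_dim, rank)` and `"full"` to `(*, in_dim, in_dim)`, and treats `"diagonal"` as an alias of `"diag"`. The `"half"` mode is omitted because its layout is not stated here, and the helper name `adf_linear_out_shapes` is hypothetical.

```python
def adf_linear_out_shapes(in_mean_shape, in_var_shape, out_dim, mode="diag"):
    """Hypothetical helper: validate input shapes and return (out_mean, out_var) shapes."""
    batch = tuple(in_mean_shape[:-1])  # the leading "*" dimensions
    in_dim = in_mean_shape[-1]
    mean_shape = batch + (out_dim,)
    if mode in ("diag", "diagonal"):
        expected_var = batch + (in_dim,)
        var_shape = batch + (out_dim,)
    elif mode == "lowrank":
        rank = in_var_shape[-1]  # the rank is carried through unchanged
        expected_var = batch + (in_dim, rank)
        var_shape = batch + (out_dim, rank)
    elif mode == "full":
        expected_var = batch + (in_dim, in_dim)
        var_shape = batch + (out_dim, out_dim)
    else:
        raise ValueError(f"mode not covered by this sketch: {mode!r}")
    if tuple(in_var_shape) != expected_var:
        raise ValueError(
            f"in_var shape {tuple(in_var_shape)} does not match {expected_var} for mode {mode!r}"
        )
    return mean_shape, var_shape


# A batch of 8 x 5 inputs with in_dim=4, mapped to out_dim=3 with a rank-2 factor.
print(adf_linear_out_shapes((8, 5, 4), (8, 5, 4, 2), 3, mode="lowrank"))
```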