torchadf.nn.functional.linear
- torchadf.nn.functional.linear(in_mean, in_var, weight, bias=None, mode='diag')
Applies a dense (fully connected) linear transform.
Assumed Density Filtering (ADF) version of torch.nn.functional.linear.
- Parameters:
- in_mean
torch.Tensor Input mean tensor. Expected shape is (*, in_dim).
- in_var
torch.Tensor Input (co-)variance tensor. Expected shape is (*, in_dim), (*, in_dim, rank), or (*, in_dim, in_dim), depending on the mode.
- weight
torch.Tensor or torch.nn.parameter.Parameter Weight matrix of the affine linear transform.
- bias
torch.Tensor or torch.nn.parameter.Parameter, optional Bias vector of the affine linear transform (Default None).
- mode
{“diag”, “diagonal”, “lowrank”, “half”, “full”}, optional Covariance propagation mode (Default “diag”).
- Returns:
- out_mean
torch.Tensor The transformed mean tensor of shape (*, out_dim).
- out_var
torch.Tensor The transformed (co-)variance tensor of shape (*, out_dim), (*, out_dim, rank), or (*, out_dim, out_dim), depending on the mode.
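The shapes listed above follow from how a Gaussian's mean and covariance pass through an affine map y = W x + b: the mean becomes W mu + b and the covariance becomes W Sigma W^T. The sketch below illustrates this in plain PyTorch for the “diag” and “full” cases only. It is not torchadf's implementation: the helper name `linear_adf_sketch` is hypothetical, the weight shape `(out_dim, in_dim)` is assumed to follow torch.nn.functional.linear, and the low-rank modes are left out.

```python
import torch


def linear_adf_sketch(in_mean, in_var, weight, bias=None, mode="diag"):
    """Hypothetical reference sketch of moment propagation through a linear map.

    Illustrative only; covers the 'diag'/'diagonal' and 'full' cases.
    """
    # The mean is transformed exactly like an ordinary linear layer input.
    out_mean = torch.nn.functional.linear(in_mean, weight, bias)
    if mode in ("diag", "diagonal"):
        # Diagonal of W diag(v) W^T: each output variance is a sum of input
        # variances weighted by the squared weights. The bias adds no variance.
        out_var = torch.nn.functional.linear(in_var, weight ** 2)
    elif mode == "full":
        # Full covariance W Sigma W^T, broadcast over the leading dimensions.
        out_var = weight @ in_var @ weight.transpose(-1, -2)
    else:
        raise ValueError(f"mode {mode!r} is not covered by this sketch")
    return out_mean, out_var


torch.manual_seed(0)
in_mean = torch.randn(4, 3)   # (*, in_dim) with * = 4, in_dim = 3
in_var = torch.rand(4, 3)     # non-negative per-feature variances
weight = torch.randn(2, 3)    # assumed shape (out_dim, in_dim)
bias = torch.randn(2)

out_mean, out_var = linear_adf_sketch(in_mean, in_var, weight, bias, mode="diag")

# Cross-check: a diagonal input covariance pushed through the 'full' path
# should reproduce the 'diag' variances on its diagonal.
full_mean, full_cov = linear_adf_sketch(
    in_mean, torch.diag_embed(in_var), weight, bias, mode="full"
)
```

With torchadf installed, the corresponding call per the signature above would be `torchadf.nn.functional.linear(in_mean, in_var, weight, bias, mode='diag')`; whether its numerical results match this sketch exactly is an assumption to verify.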