loss_functions
Module containing loss functions for training.
MSE_Loss
Wrapper for a simple MSE loss with the same return shape as the other loss wrappers in this module. Forward:
Args
yhat, ytrue
Returns
MSE loss mean
Methods:
.forward
.forward(
yhat, ytrue
)
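A minimal sketch of what such a wrapper could look like, assuming it is a torch.nn.Module around torch.nn.MSELoss with mean reduction. The class name MSELossSketch is hypothetical and is not the module's actual code.

```python
import torch
import torch.nn as nn


class MSELossSketch(nn.Module):
    """Hypothetical stand-in for MSE_Loss: mean squared error over the batch."""

    def __init__(self):
        super().__init__()
        # ASSUMPTION: mean reduction, matching "MSE loss mean" in the docs
        self.mse = nn.MSELoss(reduction="mean")

    def forward(self, yhat: torch.Tensor, ytrue: torch.Tensor) -> torch.Tensor:
        # Averages (yhat - ytrue)**2 over all elements
        return self.mse(yhat, ytrue)


loss_fn = MSELossSketch()
print(loss_fn(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.0])))  # tensor(2.5000)
```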
GaussNLLLoss
Wrapper for a Gaussian negative log likelihood loss. Forward:
Args
yhat, ytrue
Returns
NLL loss mean
Methods:
.forward
.forward(
yhat, ytrue
)
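The docs do not say how yhat carries both a mean and a variance. The sketch below assumes, purely for illustration, that yhat has shape (N, 2) with the predicted mean in column 0 and the predicted log variance in column 1, and delegates to torch.nn.functional.gaussian_nll_loss. GaussNLLSketch is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussNLLSketch(nn.Module):
    """Hypothetical stand-in for GaussNLLLoss.

    ASSUMPTION: yhat has shape (N, 2): column 0 is the predicted mean,
    column 1 the predicted log variance. The real class may differ.
    """

    def forward(self, yhat: torch.Tensor, ytrue: torch.Tensor) -> torch.Tensor:
        mean = yhat[:, 0]
        var = torch.exp(yhat[:, 1])  # exp keeps the variance strictly positive
        # full=True adds the constant 0.5*log(2*pi); reduction="mean" averages over the batch
        return F.gaussian_nll_loss(mean, ytrue, var, full=True, reduction="mean")
```

With a zero mean, zero log variance and zero targets, every datum contributes only the constant 0.5*log(2*pi), roughly 0.919.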
GaussNLL_VAR0_Loss
Gaussian negative log likelihood loss with a trainable variance parameter. Forward:
Args
yhat, ytrue
Returns
NLL loss mean
Methods:
.forward
.forward(
yhat, ytrue
)
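One way to realize a single trainable variance is to learn a scalar log variance shared by all data, which keeps the variance positive during optimization. This is a sketch under that assumption; GaussNLLVar0Sketch and init_logvar are hypothetical names. Because the loss owns a parameter, its parameters must be passed to the optimizer alongside the model's.

```python
import math

import torch
import torch.nn as nn


class GaussNLLVar0Sketch(nn.Module):
    """Hypothetical stand-in for GaussNLL_VAR0_Loss: one learned, shared variance.

    ASSUMPTION: the variance is parameterized as exp(logvar) with a single
    trainable scalar logvar.
    """

    def __init__(self, init_logvar: float = 0.0):
        super().__init__()
        self.logvar = nn.Parameter(torch.tensor(init_logvar))

    def forward(self, yhat: torch.Tensor, ytrue: torch.Tensor) -> torch.Tensor:
        # Per-datum Gaussian NLL: 0.5 * (log(2*pi) + logvar + (y - yhat)^2 / var)
        nll = 0.5 * (math.log(2 * math.pi) + self.logvar
                     + (ytrue - yhat) ** 2 / torch.exp(self.logvar))
        return nll.mean()
```

For example, `torch.optim.Adam(list(model.parameters()) + list(loss_fn.parameters()))` would train the variance jointly with a model.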
GaussNLL_VAR_Loss
Gaussian negative log likelihood loss that models logsigma as a second-order polynomial expansion of yhat, similar to the noise model in MAVE-NN. Forward:
Args
yhat, ytrue
Returns
NLL loss mean
Methods:
.calc_logsigma
.calc_logsigma(
yhat
)
Compute the noise scale, as logsigma, from a second-order polynomial expansion of yhat.
.forward
.forward(
yhat, ytrue
)
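A sketch of the polynomial noise model, assuming logsigma(yhat) = a0 + a1*yhat + a2*yhat**2 with trainable coefficients initialized to zero (so sigma starts at 1). GaussNLLVarSketch and coeffs are hypothetical names; the zero initialization is an assumption, not taken from the source.

```python
import math

import torch
import torch.nn as nn


class GaussNLLVarSketch(nn.Module):
    """Hypothetical stand-in for GaussNLL_VAR_Loss.

    logsigma(yhat) = a0 + a1*yhat + a2*yhat**2 with trainable a0, a1, a2.
    ASSUMPTION: coefficients start at zero, so sigma starts at 1.
    """

    def __init__(self):
        super().__init__()
        self.coeffs = nn.Parameter(torch.zeros(3))  # a0, a1, a2

    def calc_logsigma(self, yhat: torch.Tensor) -> torch.Tensor:
        # Evaluate the second-order polynomial in yhat elementwise
        return self.coeffs[0] + self.coeffs[1] * yhat + self.coeffs[2] * yhat ** 2

    def forward(self, yhat: torch.Tensor, ytrue: torch.Tensor) -> torch.Tensor:
        logsigma = self.calc_logsigma(yhat)
        # Gaussian NLL with sigma = exp(logsigma):
        # log(sigma) + 0.5 * ((y - yhat) / sigma)^2 + 0.5 * log(2*pi)
        nll = (logsigma
               + 0.5 * ((ytrue - yhat) / torch.exp(logsigma)) ** 2
               + 0.5 * math.log(2 * math.pi))
        return nll.mean()
```

Letting sigma depend on yhat allows the loss to down-weight regions of the prediction range where measurements are noisier (heteroscedastic noise).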
NoiseLayer
NoiseLayer(
model_params
)
Base class for the original MAVE-NN noise layers.
Args
- model_params (dict) : Dictionary of model parameters. Must contain the key polynomial_order.
Attributes
- poly_order (int) : Order of polynomial expansion for noise model.
Methods:
.compute_nlls
.compute_nlls(
yhat, ytrue
)
Compute the negative log-likelihoods for the given predictions and targets. The computation itself is defined in the derived noise model classes.
Args
- yhat (torch.Tensor) : Predictions from the model.
- ytrue (torch.Tensor) : Targets for the model.
Returns
- Tensor : The negative log-likelihoods for each sample in the batch.
.forward
.forward(
yhat, ytrue
)
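The base-class pattern described above can be sketched as follows, assuming that forward reduces the per-sample values from compute_nlls to their batch mean. NoiseLayerSketch and the toy subclass AbsErrorNoise are hypothetical; AbsErrorNoise only illustrates the override hook and is not one of this module's noise models.

```python
import torch
import torch.nn as nn


class NoiseLayerSketch(nn.Module):
    """Hypothetical base class mirroring the NoiseLayer description.

    ASSUMPTIONS: model_params carries "polynomial_order", subclasses implement
    compute_nlls, and forward averages the per-sample NLLs into a scalar loss.
    """

    def __init__(self, model_params: dict):
        super().__init__()
        self.poly_order = int(model_params["polynomial_order"])

    def compute_nlls(self, yhat: torch.Tensor, ytrue: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("Derived noise models must implement compute_nlls")

    def forward(self, yhat: torch.Tensor, ytrue: torch.Tensor) -> torch.Tensor:
        # Reduce the per-sample NLLs to the batch mean
        return self.compute_nlls(yhat, ytrue).mean()


class AbsErrorNoise(NoiseLayerSketch):
    """Toy subclass for illustration only: per-sample absolute error."""

    def compute_nlls(self, yhat, ytrue):
        return (yhat - ytrue).abs()
```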
GaussianNoise
GaussianNoise(
model_params
)
A Gaussian noise distribution for global epistasis (GE) regression. compute_nlls evaluates the negative log-likelihood using the computed logsigma.
Args
- model_params (dict) : Dictionary of model parameters.
Attributes
- poly_order (int) : Order of polynomial expansion for noise model.
Methods:
.compute_params
.compute_params(
yhat, y_true = None
)
Compute layer parameters governing p(y|yhat).
.compute_nlls
.compute_nlls(
yhat, ytrue
)
Compute negative log likelihood contributions for each datum.
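A standalone sketch of a Gaussian noise layer in this spirit, assuming logsigma(yhat) = sum over k = 0..poly_order of a_k * yhat**k with trainable a_k initialized to zero, and that forward averages the per-datum NLLs. GaussianNoiseSketch and the coefficient vector a are hypothetical; this follows the general shape of a MAVE-NN-style Gaussian noise model rather than the module's verified code.

```python
import math

import torch
import torch.nn as nn


class GaussianNoiseSketch(nn.Module):
    """Hypothetical stand-in for GaussianNoise.

    ASSUMPTION: logsigma(yhat) = sum_{k=0}^{poly_order} a_k * yhat**k,
    with trainable a_k initialized to zero (so sigma starts at 1).
    """

    def __init__(self, model_params: dict):
        super().__init__()
        self.poly_order = int(model_params["polynomial_order"])
        self.a = nn.Parameter(torch.zeros(self.poly_order + 1))

    def compute_params(self, yhat: torch.Tensor, y_true=None) -> torch.Tensor:
        # Stack yhat**0 ... yhat**K on a new last axis, then contract with a
        powers = torch.stack([yhat ** k for k in range(self.poly_order + 1)], dim=-1)
        return powers @ self.a  # logsigma, same shape as yhat

    def compute_nlls(self, yhat: torch.Tensor, ytrue: torch.Tensor) -> torch.Tensor:
        logsigma = self.compute_params(yhat)
        sigma = torch.exp(logsigma)
        # Per-datum Gaussian NLL
        return 0.5 * ((ytrue - yhat) / sigma) ** 2 + logsigma + 0.5 * math.log(2 * math.pi)

    def forward(self, yhat: torch.Tensor, ytrue: torch.Tensor) -> torch.Tensor:
        return self.compute_nlls(yhat, ytrue).mean()
```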
KLD_diag_gaussians
.KLD_diag_gaussians(
mu: torch.Tensor, logvar: torch.Tensor, p_mu: torch.Tensor,
p_logvar: torch.Tensor
)
KL divergence between a diagonal Gaussian posterior and a diagonal Gaussian prior.
Args
- mu (torch.Tensor) : mean of the posterior
- logvar (torch.Tensor) : log variance of the posterior
- p_mu (torch.Tensor) : mean of the prior
- p_logvar (torch.Tensor) : log variance of the prior
Returns
KL divergence (torch.Tensor)
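For diagonal Gaussians the KL divergence has a closed form: per dimension, KL = 0.5 * (p_logvar - logvar + (exp(logvar) + (mu - p_mu)**2) / exp(p_logvar) - 1). The sketch below evaluates it and, as an assumption, sums over the last axis; the original function may instead sum over everything, average, or return elementwise values. kld_diag_gaussians_sketch is a hypothetical name.

```python
import torch


def kld_diag_gaussians_sketch(mu: torch.Tensor, logvar: torch.Tensor,
                              p_mu: torch.Tensor, p_logvar: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation of KLD_diag_gaussians.

    KL(N(mu, exp(logvar)) || N(p_mu, exp(p_logvar))), computed per dimension
    and summed over the last axis (ASSUMPTION about the reduction).
    """
    kl = 0.5 * (p_logvar - logvar
                + (torch.exp(logvar) + (mu - p_mu) ** 2) / torch.exp(p_logvar)
                - 1.0)
    return kl.sum(dim=-1)
```

Two sanity checks follow from the formula: identical distributions give 0, and N(1, 1) against a standard normal prior gives 0.5.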