(To run the code samples, visit the GitHub repo.)
Although deep learning is commonly associated with classification problems, it is also effective for regression tasks.
Regression means we are estimating some real number, and "estimate" usually means "give me the average". For example, we might ask, "given this store's location and current promotions, estimate how many units of item X will be sold tomorrow."
But it turns out you can predict not only the average of some value, but also its variance. Predicting the stdev (or variance) is useful, for example, when you are planning inventory or other resources.
Perhaps tomorrow the store will sell an average of 10.0 units of item X, with a stdev of 1.0. But you want to be 99% sure you'll have enough inventory on hand, and you think the distribution is basically Normal (or Gaussian). How much inventory will you need? Definitely more than 10.0... 10 would only be enough 50% of the time, and the other half of the time you'd be missing out on revenue. To estimate the number of items you need, you can look up the z-table, get a critical value of 2.33 for a one-tailed 99% confidence interval, and therefore plan to have 12.33 units of item X on hand. You'll only run out of inventory 1% of the time.
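As a quick sanity check of that arithmetic, Python's standard library can compute the same quantile directly (the demand numbers here are just the made-up ones from above):

```python
from statistics import NormalDist

# Hypothetical demand forecast: mean 10.0 units, stdev 1.0.
demand = NormalDist(mu=10.0, sigma=1.0)

# Inventory level that covers demand 99% of the time (one-tailed).
stock = demand.inv_cdf(0.99)
print(round(stock, 2))  # 12.33
```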
In regression, we predict some value y based on inputs x. For instance, a model might predict height (y) based on age (x1) and weight (x2). However, the model isn't predicting a specific person's height; rather, it predicts the average (or mean) height over all people with age x1 and weight x2.
As any introductory machine learning course teaches, the loss function for regression is Mean Squared Error (MSE). Minimizing MSE yields a prediction of the mean of the target value.
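You can check this claim numerically: among all constant predictions c, the sample mean gives the lowest MSE. A minimal sketch (the data values are made up):

```python
# For a fixed dataset, the constant prediction that minimizes MSE is the mean.
data = [2.0, 3.0, 5.0, 10.0]
mean = sum(data) / len(data)  # 5.0

def mse(c):
    return sum((y - c) ** 2 for y in data) / len(data)

# The MSE at the mean is lower than at nearby constants.
assert mse(mean) < mse(mean - 0.1)
assert mse(mean) < mse(mean + 0.1)
```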
One core reason MSE is appropriate is that it assumes the variance (or standard deviation) of the errors (residuals) between predictions and actual data is constant. To explore why, let's look at the probability density function (PDF) of a normal distribution:

`pdf(y | u, sigma) = 1 / (sigma * sqrt(2 * pi)) * exp(-((y - u) ** 2) / (2 * sigma ** 2))`
In this equation, the term `(y - u) ** 2` represents the squared difference between the mean and a data point. When training a model using MSE, the model considers this as the error between a data point y and the predicted mean u, squares it, and then averages it across all data points. This turns out to be the heart of why we use MSE!
Why can we ignore the variance term sigma? Because we assume the variance of the error is constant. This is called "homoscedasticity" ("homo" means "same"). When calculating the gradient, any constant factor can be ignored.
What about the exponent function? Training neural networks is typically based on Maximum Likelihood Estimation (MLE), which finds the parameters that maximize the joint probability of the observed outcomes. Joint probability involves multiplying many small numbers—so for numerical stability, we take the logarithm, turning the product into a sum of logs:

`log(p1 * p2 * ... * pn) = log(p1) + log(p2) + ... + log(pn)`
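To see why this matters numerically, try multiplying a thousand small probabilities: the product underflows to zero, while the sum of logs stays perfectly representable. A small sketch:

```python
import math

probs = [0.01] * 1000  # a thousand likelihood terms

product = 1.0
for p in probs:
    product *= p
print(product)  # 0.0 -- underflows past the smallest float

log_sum = sum(math.log(p) for p in probs)
print(log_sum)  # about -4605.17, no underflow
```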
This is why MSE doesn't have any pesky exponent or log functions. A logarithm applied to an exponential cancels it out:

`log(exp(z)) = z`
Well, if your variance is constant, you don't really need a model to predict it. You can measure it once and it'll always be the same.
However, when variance isn't constant (a situation known as "heteroscedasticity", "hetero" means "different"), things get more complicated.
The solution is to give your model two output heads: one for the mean of the value you are predicting and one for its variance. Then, undo the shortcuts we used earlier (constant variance, log canceling the exponent).
Using MLE, we still need to minimize the Negative Log-Likelihood (NLL). Likelihood, in frequentist statistics, is the same as probability, and probability is given by the PDF. So, we minimize:

`-log(pdf(y | theta))`
where theta represents the predictions of your neural network, specifically the mean and variance.
Well, we know the PDF function from above... so we get the following:

`-log(pdf(y | u, sigma)) = log(sigma) + log(sqrt(2 * pi)) + (y - u) ** 2 / (2 * sigma ** 2)`
After some basic manipulation and dropping constants (which are ignored during backprop), that's equivalent to minimizing:

`log(sigma) + (y - u) ** 2 / (2 * sigma ** 2)`
Let's try to build some intuition about what this loss function is doing.
On the right side, we see our old MSE formulation. That much is obvious: the prediction for the mean should be as close as possible to the data, roughly minimizing the square of the error.

But we divide it by the variance. What does that mean? Well, it means that the bigger the variance, the smaller the term `(y - u) ** 2 / (2 * sigma ** 2)` becomes. In other words, we decrease the penalty of the MSE term as we increase the variance. That much makes sense... if data is really dispersed, it's fair to reduce the weight of the MSE penalty.
But on the other hand, we directly minimize the log(sigma) term. Since log(sigma) increases monotonically with sigma, minimizing it pushes the variance down. That also makes sense... we can't let backprop greedily minimize the loss just by estimating an infinite variance.
To summarize, -log(pdf) simplifies to MSE when variance is constant. MSE aligns perfectly with the MLE concept of NLL, where the log cancels out the exponent in the normal distribution's PDF, leaving us with the simple MSE equation. When we want to estimate variance, we can do so by minimizing the negative log of the Normal distribution's PDF.
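You can verify the algebra numerically: the exact negative log of the normal PDF differs from `log(sigma) + (y - u) ** 2 / (2 * sigma ** 2)` only by the constant `0.5 * log(2 * pi)`, which backprop ignores. A quick stdlib check (the sample values are arbitrary):

```python
import math
from statistics import NormalDist

y, u, sigma = 2.0, 1.0, 0.5

# Exact negative log-likelihood from the normal PDF.
nll_exact = -math.log(NormalDist(u, sigma).pdf(y))

# The simplified loss, plus the constant dropped during backprop.
simplified = math.log(sigma) + (y - u) ** 2 / (2 * sigma ** 2)
constant = 0.5 * math.log(2 * math.pi)

assert abs(nll_exact - (simplified + constant)) < 1e-9
```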
BTW, classification is a much more common task for deep learning, and when you are classifying something, you are working with a Generalized Bernoulli (categorical) distribution. The NLL in that case is still -log(pdf), and if you use the probability mass function of the Generalized Bernoulli distribution, you can directly derive the equation for the cross-entropy loss.
Although you could also arrive at the same equation through the concepts of Shannon's entropy and KL divergence (in fact, that's why it's called "entropy"), in the context of MLE and NLL we could just as well call it the NLL of the Generalized Bernoulli distribution.
Of course, just because we know that NLL tells us the most likely parameters of a model... it does not guarantee that backprop will get us there. For that, we need a property called convexity, and that's a whole other topic. Convexity exists for crossentropy when the activation function is sigmoid or softmax, and for MSE, but other distributions like Gaussian Mixture Models are prone to local minima.
Some of the first ideas here came from econometrics, where researchers wanted to model time-series data but realized that some periods exhibit higher variance than others. This book talks about it further in the context of ML, and of course we have to go to Fisher for MLE, Gauss and Laplace for the Central Limit Theorem, and Rockafellar for convexity.
We'll train a model to estimate the expected value and variance of the output of a function, based on the input.
We'll have two clusters of data: if the input is between zero and one, the output is drawn from a normal distribution with mean 1 and stdev 1 (variance 1); if the input is between two and three, the output is drawn from a normal distribution with mean -1 and stdev 2 (variance 4).
```python
import torch
import torch.nn as nn

x1 = torch.rand(10000, 1)      # Uniform between 0 and 1.
y1 = torch.randn(10000, 1) + 1
x2 = torch.rand(10000, 1) + 2  # Uniform between 2 and 3.
y2 = torch.randn(10000, 1) * 2 - 1

x = torch.cat([x1, x2])
y = torch.cat([y1, y2])

class Data(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

dataloader = torch.utils.data.DataLoader(Data(x, y), batch_size=100, shuffle=True, drop_last=True)
```
The model is just an MLP (multilayer perceptron) with two heads, and the loss function is just the NLL applied to the key parameters (theta) of a Gaussian distribution: the mean and the variance.
```python
class Predictor(nn.Module):
    def __init__(self):
        super(Predictor, self).__init__()
        hidden_depth = 100
        self.dense1 = nn.Linear(1, hidden_depth)
        self.act1 = nn.Sigmoid()
        self.dense2 = nn.Linear(hidden_depth, hidden_depth)
        self.act2 = nn.Sigmoid()
        self.dense3 = nn.Linear(hidden_depth, hidden_depth)
        self.act3 = nn.Sigmoid()
        self.dense4 = nn.Linear(hidden_depth, hidden_depth)
        self.act4 = nn.Sigmoid()
        self.dense5 = nn.Linear(hidden_depth, hidden_depth)
        self.act5 = nn.Sigmoid()
        self.dense6 = nn.Linear(hidden_depth, hidden_depth)
        self.act6 = nn.Sigmoid()
        self.dense7 = nn.Linear(hidden_depth, hidden_depth)
        self.act7 = nn.Sigmoid()
        self.dense8 = nn.Linear(hidden_depth, hidden_depth)
        self.act8 = nn.Sigmoid()
        self.dense9 = nn.Linear(hidden_depth, hidden_depth)
        self.act9 = nn.Sigmoid()
        self.dense10 = nn.Linear(hidden_depth, hidden_depth)
        self.mean_y = nn.Linear(hidden_depth, 1)
        self.var_y = nn.Linear(hidden_depth, 1)

    def forward(self, x):
        x = self.act1(self.dense1(x))
        # Residual connections keep gradients flowing through the deep stack.
        x = x + self.act2(self.dense2(x))
        x = x + self.act3(self.dense3(x))
        x = x + self.act4(self.dense4(x))
        x = x + self.act5(self.dense5(x))
        x = x + self.act6(self.dense6(x))
        x = x + self.act7(self.dense7(x))
        x = x + self.act8(self.dense8(x))
        x = x + self.act9(self.dense9(x))
        x = self.dense10(x)
        mean = self.mean_y(x)
        var = torch.exp(self.var_y(x))  # exp keeps the variance positive
        return mean, var

model = Predictor()

# Clip gradients to stabilize training.
for p in model.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -1.0, +1.0))
```
```python
class MeanVarianceLoss(nn.Module):
    """Calculates the negative log likelihood of seeing a target value given a mean and a variance."""

    def __init__(self):
        super(MeanVarianceLoss, self).__init__()

    def forward(self, mean, var, target):
        normal = torch.distributions.Normal(mean, torch.sqrt(var))
        # For numerical stability, the torch distributions library
        # returns log(pdf(target | mean, variance)), not pdf(target | mean, variance) directly.
        log_prob = normal.log_prob(target)
        # NLL is the negative log probability.
        nll = -log_prob
        return torch.mean(nll)

criterion = MeanVarianceLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # lr is overridden by the schedule below
```
```python
from tqdm import tqdm

for epoch in range(100):
    # Step the learning rate down every 20 epochs.
    if 0 <= epoch < 20:
        optimizer.param_groups[0]['lr'] = 1e-3
    elif 20 <= epoch < 40:
        optimizer.param_groups[0]['lr'] = 1e-4
    elif 40 <= epoch < 60:
        optimizer.param_groups[0]['lr'] = 1e-5
    elif 60 <= epoch < 80:
        optimizer.param_groups[0]['lr'] = 1e-6
    else:
        optimizer.param_groups[0]['lr'] = 1e-7
    for x, y in dataloader:
        optimizer.zero_grad()
        mean, var = model(x)
        loss = criterion(mean, var, y)
        loss.backward()
        optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")
```
```
Epoch 0, Loss: 2.313269853591919
Epoch 10, Loss: 1.6423704624176025
Epoch 20, Loss: 2.411954402923584
Epoch 30, Loss: 1.5081806182861328
Epoch 40, Loss: 1.5562080144882202
Epoch 50, Loss: 1.6598745584487915
Epoch 60, Loss: 1.8379545211791992
Epoch 70, Loss: 1.839219093322754
Epoch 80, Loss: 1.7417705059051514
Epoch 90, Loss: 1.529351830482483
```
```python
mean, var = model(torch.tensor([0.5]))
print("Estimate of mean and var for first cluster of data should be mean=1, var=1.")
print(f"Mean estimate: {mean.item()}, Var estimate: {var.item()}")
print()

mean, var = model(torch.tensor([2.5]))
print("Estimate of mean and var for second cluster of data should be mean=-1, var=4.")
print(f"Mean estimate: {mean.item()}, Var estimate: {var.item()}")
```
```
Estimate of mean and var for first cluster of data should be mean=1, var=1.
Mean estimate: 1.0142967700958252, Var estimate: 0.8464043736457825

Estimate of mean and var for second cluster of data should be mean=-1, var=4.
Mean estimate: -0.8514268398284912, Var estimate: 3.932140827178955
```
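Tying this back to the inventory example: once the model predicts a mean and variance for a given input, the same one-tailed 99% calculation works per prediction. A sketch using rounded versions of the second cluster's estimates printed above:

```python
import math
from statistics import NormalDist

# Approximate model outputs for the second cluster (from the run above).
mean_est, var_est = -0.85, 3.93

# Stock level covering 99% of outcomes: roughly mean + 2.33 * stdev.
stock = NormalDist(mean_est, math.sqrt(var_est)).inv_cdf(0.99)
print(round(stock, 2))  # 3.76
```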