Towards a loss function for YOLO
posted on 2020-09-01T09:57:11Z

Writing the second part of the YOLO series took a lot longer than I care to admit... but it's finally here! In this part we'll go over the definition of a loss function to train the YOLO architecture.
Index¶
- part 1: introduction to YOLO and TinyYOLOv2 implementation
- part 2 (this post): towards a YOLO loss function
- part 3 (in progress): transfer learning with YOLO
In the previous post, we saw how to make bounding box predictions with YOLO. But how do they compare with the ground truth boxes? For this we need a way to load the ground truths:
def load_bboxes(idx, size, num_bboxes=None, device="cpu"):
    boxfilename = f"VOCdevkit/VOC2012/Annotations/2008_{str(idx).zfill(6)}.xml"
    imgfilename = f"VOCdevkit/VOC2012/JPEGImages/2008_{str(idx).zfill(6)}.jpg"
    img = Image.open(imgfilename)  # img won't be loaded into memory, just needed for metadata.
    try:
        width, height = size
    except TypeError:
        width = height = size
    scale = min(width / img.width, height / img.height)
    new_width, new_height = int(img.width * scale), int(img.height * scale)
    diff_width = (width - new_width) / width * img.width
    diff_height = (height - new_height) / height * img.height
    bboxes = []
    with open(boxfilename, "r") as file:
        for line in file:
            if "<name>" in line:
                class_ = line.split("<name>")[-1].split("</name>")[0].strip()
            elif "<xmin>" in line:
                x0 = float(line.split("<xmin>")[-1].split("</xmin>")[0].strip())
            elif "<xmax>" in line:
                x1 = float(line.split("<xmax>")[-1].split("</xmax>")[0].strip())
            elif "<ymin>" in line:
                y0 = float(line.split("<ymin>")[-1].split("</ymin>")[0].strip())
            elif "<ymax>" in line:
                y1 = float(line.split("<ymax>")[-1].split("</ymax>")[0].strip())
            elif "</object>" in line:
                if class_ not in CLASS2NUM:
                    continue
                bbox = [
                    (diff_width/2 + (x0+x1)/2) / (img.width + diff_width),     # center x
                    (diff_height/2 + (y0+y1)/2) / (img.height + diff_height),  # center y
                    (max(x0, x1) - min(x0, x1)) / (img.width + diff_width),    # width
                    (max(y0, y1) - min(y0, y1)) / (img.height + diff_height),  # height
                    1.0,  # confidence
                    CLASS2NUM[class_],  # class idx
                ]
                # store (area, bbox) tuples so the boxes can be sorted by area below
                bboxes.append((bbox[2]*bbox[3]*(img.width+diff_width)*(img.height+diff_height), bbox))
    bboxes = torch.tensor([bbox for _, bbox in sorted(bboxes)], dtype=torch.get_default_dtype(), device=device)
    if num_bboxes:
        if num_bboxes > len(bboxes):
            zeros = torch.zeros((num_bboxes - bboxes.shape[0], 6), dtype=torch.get_default_dtype(), device=device)
            zeros[:, -1] = -1  # class index -1 marks padding rows (no object)
            bboxes = torch.cat([bboxes, zeros], 0)
        elif num_bboxes < len(bboxes):
            bboxes = bboxes[:num_bboxes]
    return bboxes


def load_bboxes_batch(idxs, size, num_bboxes, device="cpu"):
    bboxes = [load_bboxes(idx, size, num_bboxes=num_bboxes, device=device) for idx in idxs]
    bboxes = torch.stack(bboxes, 0)
    return bboxes
batch_imgs = load_image_batch([8,16,33,60], size=320)
batch_labels = load_bboxes_batch([8,16,33,60], size=320, num_bboxes=10)
batch_predictions = network(batch_imgs)
print("labels:")
show_images_with_boxes(batch_imgs, batch_labels)
print("predictions:")
show_images_with_boxes(batch_imgs, nms(filter_boxes(batch_predictions, 0.2), 0.5))
Visually, these labels correspond quite well to the predictions, but we need a way to quantify how well they match.
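The basic ingredient for this quantification is the intersection over union (IoU) between two boxes. The loss function defined below relies on two helpers, iou and iou_wh, which are not reproduced in this post; a minimal sketch of what they are assumed to compute (pairwise IoU between boxes in (cx, cy, w, h) format, and IoU between width/height pairs as if the boxes shared a center) could look like this:

import torch

def iou(a, b):
    """Pairwise IoU between two sets of (cx, cy, w, h) boxes: returns a (len(a), len(b)) matrix."""
    ax0, ay0 = a[:, 0] - a[:, 2] / 2, a[:, 1] - a[:, 3] / 2
    ax1, ay1 = a[:, 0] + a[:, 2] / 2, a[:, 1] + a[:, 3] / 2
    bx0, by0 = b[:, 0] - b[:, 2] / 2, b[:, 1] - b[:, 3] / 2
    bx1, by1 = b[:, 0] + b[:, 2] / 2, b[:, 1] + b[:, 3] / 2
    ix = (torch.min(ax1[:, None], bx1[None, :]) - torch.max(ax0[:, None], bx0[None, :])).clamp(min=0)
    iy = (torch.min(ay1[:, None], by1[None, :]) - torch.max(ay0[:, None], by0[None, :])).clamp(min=0)
    inter = ix * iy
    union = (a[:, 2] * a[:, 3])[:, None] + (b[:, 2] * b[:, 3])[None, :] - inter
    return inter / union.clamp(min=1e-8)

def iou_wh(a, b):
    """IoU between (w, h) pairs, as if all boxes shared the same center."""
    inter = torch.min(a[:, 0, None], b[None, :, 0]) * torch.min(a[:, 1, None], b[None, :, 1])
    union = (a[:, 0] * a[:, 1])[:, None] + (b[:, 0] * b[:, 1])[None, :] - inter
    return inter / union.clamp(min=1e-8)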
YOLO Loss function¶
Quantifying how well each predicted bounding box fits the labeled boxes can be done by defining a custom loss function. In the case of YOLO, however, this is quite an undertaking. The loss function used to train YOLO from scratch is this monster (adapted here for YOLOv2):
$$ \begin{align*} \lambda_\text{coord}& \sum_{i = 0}^{M\times N} \sum_{j = 0}^{A} \mathbb{1}_{ij}^{\text{obj}} \left[ \left( x_i - \hat{x}_i \right)^2 + \left( y_i - \hat{y}_i \right)^2 \right] \\ + \lambda_\text{coord}& \sum_{i = 0}^{M\times N} \sum_{j = 0}^{A} \mathbb{1}_{ij}^{\text{obj}} \left[ \left( \sqrt{w_i} - \sqrt{\hat{w}_i} \right)^2 + \left( \sqrt{h_i} - \sqrt{\hat{h}_i} \right)^2 \right] \\ + &\sum_{i = 0}^{M\times N} \sum_{j = 0}^{A} \mathbb{1}_{ij}^{\text{obj}} \left( C_i - \hat{C}_i \right)^2 \\ + \lambda_\text{noobj}& \sum_{i = 0}^{M\times N} \sum_{j = 0}^{A} \mathbb{1}_{ij}^{\text{noobj}} \left( C_i - \hat{C}_i \right)^2 \\ + &\sum_{i = 0}^{M\times N} \sum_{j = 0}^{A} \mathbb{1}_{ij}^{\text{obj}} \sum_{c \in \text{classes}} \left( p_i(c) - \hat{p}_i(c) \right)^2 \end{align*} $$

This loss function can be broken down as follows.
- The summation $\sum_{i = 0}^{M\times N}$ that's present in every term of the loss is the summation over each of the predictions (# anchors $\times$ # x-values $\times$ # y-values).
- The first term of this loss function describes the sum-squared error between the center coordinates of each predicted bounding box and its target.
- The second term describes the sum-squared error of the square roots of the widths and heights. The square root is there to (partially) reduce the importance of size deviations on large bounding boxes compared to small ones.
- The third term describes the sum-squared error in the confidences for boxes that do contain objects.
- The fourth term describes the sum-squared error in the confidences for boxes that don't contain objects.
- Finally, the last term describes the sum-squared error on the classification vector, but only if an object is present in the bounding box. A small numeric sketch of how these terms combine follows below.
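As a concrete illustration of how these terms interact, here is a tiny numeric sketch for a single grid cell that contains an object and a single cell that does not. All values are purely hypothetical, and the λ weights are the ones from the original YOLO paper; the YOLOLoss class below uses its own defaults and a vectorized formulation.

lambda_coord, lambda_noobj = 5.0, 0.5  # weights from the original YOLO paper (illustration only)

x, y, w, h, C, p       = 0.45, 0.52, 0.30, 0.60, 0.80, [0.7, 0.3]   # prediction (cell with object)
tx, ty, tw, th, tC, tp = 0.50, 0.50, 0.25, 0.55, 1.00, [1.0, 0.0]   # target     (cell with object)
C_noobj = 0.10                                                      # predicted confidence (empty cell)

loss  = lambda_coord * ((x - tx)**2 + (y - ty)**2)                  # term 1: center coordinates
loss += lambda_coord * ((w**.5 - tw**.5)**2 + (h**.5 - th**.5)**2)  # term 2: sqrt of widths/heights
loss += (C - tC)**2                                                 # term 3: confidence (object)
loss += lambda_noobj * (C_noobj - 0.0)**2                           # term 4: confidence (no object)
loss += sum((pi - tpi)**2 for pi, tpi in zip(p, tp))                # term 5: classification
print(loss)  # ~0.26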
class YOLOLoss(torch.nn.modules.loss._Loss):
    """ A loss function to train YOLO v2

    Args:
        anchors (optional, list): the list of anchors (should be the same anchors as the ones defined in the YOLO class)
        seen (optional, int): the number of images the network has already been trained on
        coord_prefill (optional, int): the number of images for which the predicted bboxes will be centered in the image
        threshold (optional, float): minimum iou necessary to have a predicted bbox match a target bbox
        lambda_coord (optional, float): hyperparameter controlling the importance of the bbox coordinate predictions
        lambda_noobj (optional, float): hyperparameter controlling the importance of the bboxes containing no objects
        lambda_obj (optional, float): hyperparameter controlling the importance of the bboxes containing objects
        lambda_cls (optional, float): hyperparameter controlling the importance of the class prediction if the bbox contains an object
    """

    def __init__(
        self,
        anchors=(
            (1.08, 1.19),
            (3.42, 4.41),
            (6.63, 11.38),
            (9.42, 5.11),
            (16.62, 10.52),
        ),
        seen=0,
        coord_prefill=12800,
        threshold=0.6,
        lambda_coord=1.0,
        lambda_noobj=1.0,
        lambda_obj=5.0,
        lambda_cls=1.0,
    ):
        super().__init__()
        if not torch.is_tensor(anchors):
            anchors = torch.tensor(anchors, dtype=torch.get_default_dtype())
        else:
            anchors = anchors.data.to(torch.get_default_dtype())
        self.register_buffer("anchors", anchors)
        self.seen = int(seen + .5)
        self.coord_prefill = int(coord_prefill + .5)
        self.threshold = float(threshold)
        self.lambda_coord = float(lambda_coord)
        self.lambda_noobj = float(lambda_noobj)
        self.lambda_obj = float(lambda_obj)
        self.lambda_cls = float(lambda_cls)
        self.mse = torch.nn.MSELoss(reduction='sum')
        self.cel = torch.nn.CrossEntropyLoss(reduction='sum')
    def forward(self, x, y):
        nT = y.shape[1]
        nA = self.anchors.shape[0]
        nB, _, nH, nW = x.shape
        nPixels = nH * nW
        nAnchors = nA * nPixels
        y = y.to(dtype=x.dtype, device=x.device)
        x = x.view(nB, nA, -1, nH, nW).permute(0, 1, 3, 4, 2)
        nC = x.shape[-1] - 5
        self.seen += nB

        anchors = self.anchors.to(dtype=x.dtype, device=x.device)
        coord_mask = torch.zeros(nB, nA, nH, nW, 1, requires_grad=False, dtype=x.dtype, device=x.device)
        conf_mask = torch.ones(nB, nA, nH, nW, requires_grad=False, dtype=x.dtype, device=x.device) * self.lambda_noobj
        cls_mask = torch.zeros(nB, nA, nH, nW, requires_grad=False, dtype=torch.bool, device=x.device)
        tcoord = torch.zeros(nB, nA, nH, nW, 4, requires_grad=False, dtype=x.dtype, device=x.device)
        tconf = torch.zeros(nB, nA, nH, nW, requires_grad=False, dtype=x.dtype, device=x.device)
        tcls = torch.zeros(nB, nA, nH, nW, requires_grad=False, dtype=x.dtype, device=x.device)

        coord = torch.cat([
            x[:, :, :, :, 0:1].sigmoid(),  # X center
            x[:, :, :, :, 1:2].sigmoid(),  # Y center
            x[:, :, :, :, 2:3],            # Width
            x[:, :, :, :, 3:4],            # Height
        ], -1)

        range_y, range_x = torch.meshgrid(
            torch.arange(nH, dtype=x.dtype, device=x.device),
            torch.arange(nW, dtype=x.dtype, device=x.device),
        )
        anchor_x, anchor_y = anchors[:, 0], anchors[:, 1]

        x = torch.cat([
            (x[:, :, :, :, 0:1].sigmoid() + range_x[None, None, :, :, None]),  # X center
            (x[:, :, :, :, 1:2].sigmoid() + range_y[None, None, :, :, None]),  # Y center
            (x[:, :, :, :, 2:3].exp() * anchor_x[None, :, None, None, None]),  # Width
            (x[:, :, :, :, 3:4].exp() * anchor_y[None, :, None, None, None]),  # Height
            x[:, :, :, :, 4:5].sigmoid(),  # confidence
            x[:, :, :, :, 5:],  # classes (NOTE: no softmax here bc CEL is used later, which works on logits)
        ], -1)

        conf = x[..., 4]
        cls = x[..., 5:].reshape(-1, nC)
        x = x[..., :4].detach()  # gradients are tracked in coord -> not here anymore.

        if self.seen < self.coord_prefill:
            # during the first coord_prefill images, nudge all predicted boxes towards the center of their cell
            coord_mask.fill_(np.sqrt(.01 / self.lambda_coord))
            tcoord[..., 0].fill_(0.5)
            tcoord[..., 1].fill_(0.5)

        for b in range(nB):
            gt = y[b][(y[b, :, -1] >= 0)[:, None].expand_as(y[b])].view(-1, 6)[:, :4]
            gt[:, ::2] *= nW
            gt[:, 1::2] *= nH
            if gt.numel() == 0:  # no ground truth for this image
                continue

            # Set confidence mask of matching detections to 0
            iou_gt_pred = iou(gt, x[b:(b+1)].view(-1, 4))
            mask = (iou_gt_pred > self.threshold).sum(0) >= 1
            conf_mask[b][mask.view_as(conf_mask[b])] = 0

            # Find best anchor for each gt
            iou_gt_anchors = iou_wh(gt[:, 2:], anchors)
            _, best_anchors = iou_gt_anchors.max(1)

            # Set masks and target values for each gt
            nGT = gt.shape[0]
            gi = gt[:, 0].clamp(0, nW-1).long()
            gj = gt[:, 1].clamp(0, nH-1).long()
            conf_mask[b, best_anchors, gj, gi] = self.lambda_obj
            tconf[b, best_anchors, gj, gi] = iou_gt_pred.view(nGT, nA, nH, nW)[torch.arange(nGT), best_anchors, gj, gi]
            coord_mask[b, best_anchors, gj, gi, :] = (2 - (gt[:, 2] * gt[:, 3]) / nPixels)[..., None]
            tcoord[b, best_anchors, gj, gi, 0] = gt[:, 0] - gi.float()
            tcoord[b, best_anchors, gj, gi, 1] = gt[:, 1] - gj.float()
            tcoord[b, best_anchors, gj, gi, 2] = (gt[:, 2] / anchors[best_anchors, 0]).log()
            tcoord[b, best_anchors, gj, gi, 3] = (gt[:, 3] / anchors[best_anchors, 1]).log()
            cls_mask[b, best_anchors, gj, gi] = 1
            tcls[b, best_anchors, gj, gi] = y[b, torch.arange(nGT), -1]

        coord_mask = coord_mask.sqrt()
        conf_mask = conf_mask.sqrt()
        tcls = tcls[cls_mask].view(-1).long()
        cls_mask = cls_mask.view(-1, 1).expand(nB*nA*nH*nW, nC)
        cls = cls[cls_mask].view(-1, nC)

        loss_coord = self.lambda_coord * self.mse(coord*coord_mask, tcoord*coord_mask) / (2 * nB)
        loss_conf = self.mse(conf*conf_mask, tconf*conf_mask) / (2 * nB)
        loss_cls = self.lambda_cls * self.cel(cls, tcls) / nB
        return loss_coord + loss_conf + loss_cls
lossfunc = YOLOLoss(anchors=network.anchors)
batch_output = network(batch_imgs, yolo=False)
lossfunc(batch_output, batch_labels)
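Note that the loss acts on the raw network output (hence yolo=False) rather than on the decoded bounding boxes. As a quick sanity check, assuming the TinyYOLOv2 network from part 1 (5 anchors, 20 VOC classes) and 320×320 inputs, the expected shapes are:

print(batch_output.shape)  # (nB, nA*(5+nC), nH, nW) -> torch.Size([4, 125, 10, 10])
print(batch_labels.shape)  # (nB, num_bboxes, 6)     -> torch.Size([4, 10, 6])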
Training¶
To check that the YOLO loss actually works, we'll first randomly re-initialize the last layer of the TinyYOLOv2 network and then retrain just that layer on the VOC dataset:
for p in network.conv9.parameters():
    try:
        torch.nn.init.kaiming_normal_(p)
    except ValueError:  # kaiming init fails for 1D tensors (biases) -> fall back to normal init
        torch.nn.init.normal_(p)
batch_predictions = network(batch_imgs)
show_images_with_boxes(batch_imgs, batch_labels)
show_images_with_boxes(batch_imgs, nms(filter_boxes(batch_predictions, 0.2), 0.5))
Let's now do the training:
batch_size = 64
all_idxs = np.array([int(fn.split("2008_")[-1].split(".jpg")[0]) for fn in sorted(glob.glob("VOCdevkit/VOC2012/JPEGImages/2008_*"))], dtype=int)
lossfunc = YOLOLoss(anchors=network.anchors, coord_prefill=int(5*all_idxs.shape[0]))
optimizer = torch.optim.Adam(network.conv9.parameters(), lr=0.003)
%%time
device = torch.device("cuda")
network = network.to(device)
np.random.RandomState(seed=42).shuffle(all_idxs)
valid_idxs = all_idxs[-4*batch_size:]
train_idxs = all_idxs[:-4*batch_size]
for e in range(20):
    np.random.shuffle(train_idxs)
    # split the training indices into batches of roughly batch_size images
    range_ = tqdm(np.array_split(train_idxs, len(train_idxs) // batch_size))
    with torch.no_grad():
        valid_imgs = load_image_batch(valid_idxs, size=320).to(device)
        valid_labels = load_bboxes_batch(valid_idxs, size=320, num_bboxes=10)
        valid_predictions = network(valid_imgs, yolo=False)
        valid_loss = lossfunc(valid_predictions, valid_labels).item()
        range_.set_postfix(valid_loss=valid_loss)
    for i, idxs in enumerate(range_):
        optimizer.zero_grad()
        batch_imgs = load_image_batch(idxs, size=320).to(device)
        batch_labels = load_bboxes_batch(idxs, size=320, num_bboxes=10)
        batch_predictions = network(batch_imgs, yolo=False)
        loss = lossfunc(batch_predictions, batch_labels)
        range_.set_postfix(loss=loss.item(), valid_loss=valid_loss)
        loss.backward()
        optimizer.step()
network = network.to("cpu")
network = network.to("cpu")
imgs = load_image_batch([8, 16, 33, 60], size=320)
output_tensor = network(imgs)
filtered_tensor = filter_boxes(output_tensor, 0.3)
nms_tensor = nms(filtered_tensor, 0.3)
show_images_with_boxes(imgs, nms_tensor)
Index¶
- part 1: introduction to YOLO and TinyYOLOv2 implementation
- part 2 (this post): towards a YOLO loss function
- part 3 (in progress): transfer learning with YOLO
Credits¶
This part would not have been possible without the excellent lightnet implementation to guide me. Please check it out for a more complete PyTorch re-implementation of the YOLO architectures.
If you like this post, consider leaving a comment or starring it on GitHub.