
Towards a loss function for YOLO

posted on 2020-09-01T09:57:11Z · view page on GitHub
tags: machine learning (7) · python (6) · computer vision (3) · yolo (2)

Writing the second part of the YOLO series took a lot longer than I care to admit... but it's finally here! In this part, we'll go over the definition of a loss function to train the YOLO architecture.


In the previous post, we saw how to make bounding box predictions with YOLO. But how do they compare with the ground truth boxes? For this we need a way to load the ground truths:

In [2]:
def load_bboxes(idx, size, num_bboxes=None, device="cpu"):
    boxfilename = f"VOCdevkit/VOC2012/Annotations/2008_{str(idx).zfill(6)}.xml"
    imgfilename = f"VOCdevkit/VOC2012/JPEGImages/2008_{str(idx).zfill(6)}.jpg"
    img = Image.open(imgfilename) # img won't be loaded into memory, just needed for metadata.
    try:
        width, height = size
    except TypeError:
        width = height = size
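    # letterboxing: the image is scaled to fit inside a (width, height) canvas while keeping its
    # aspect ratio; diff_width / diff_height measure the leftover padding, mapped back to
    # original-image coordinates so the box coordinates can be normalized consistently below.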
    scale = min(width / img.width, height / img.height)
    new_width, new_height = int(img.width * scale), int(img.height * scale)
    diff_width  = (width - new_width)/width*img.width
    diff_height = (height - new_height)/height*img.height
    bboxes = []
    with open(boxfilename, "r") as file:
        for line in file:
            if "<name>" in line:
                class_ = line.split("<name>")[-1].split("</name>")[0].strip()
            elif "<xmin>" in line:
                x0 = float(line.split("<xmin>")[-1].split("</xmin>")[0].strip())
            elif "<xmax>" in line:
                x1 = float(line.split("<xmax>")[-1].split("</xmax>")[0].strip())
            elif "<ymin>" in line:
                y0 = float(line.split("<ymin>")[-1].split("</ymin>")[0].strip())
            elif "<ymax>" in line:
                y1 = float(line.split("<ymax>")[-1].split("</ymax>")[0].strip())
            elif "</object>" in line:
                if class_ not in CLASS2NUM:
                    continue
                bbox = [
                    (diff_width/2 + (x0+x1)/2) / (img.width + diff_width), # center x
                    (diff_height/2 + (y0+y1)/2) / (img.height + diff_height), # center y
                    (max(x0,x1) - min(x0,x1)) / (img.width + diff_width), # width
                    (max(y0,y1) - min(y0,y1)) / (img.height + diff_height), # height
                    1.0, # confidence
                    CLASS2NUM[class_], # class idx
                ]
                bboxes.append((bbox[2]*bbox[3]*(img.width+diff_width)*(img.height+diff_height),bbox))

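    # sort the boxes by area (ascending) and, if requested, pad or truncate to exactly
    # num_bboxes rows; padding rows get a class index of -1 so they can be filtered out later.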
    bboxes = torch.tensor([bbox for _, bbox in sorted(bboxes)], dtype=torch.get_default_dtype(), device=device)
    if num_bboxes:
        if num_bboxes > len(bboxes):
            zeros = torch.zeros((num_bboxes - bboxes.shape[0], 6), dtype=torch.get_default_dtype(), device=device)
            zeros[:, -1] = -1
            bboxes = torch.cat([bboxes, zeros], 0)
        elif num_bboxes < len(bboxes):
            bboxes = bboxes[:num_bboxes]

    return bboxes

def load_bboxes_batch(idxs, size, num_bboxes, device="cpu"):
    bboxes = [load_bboxes(idx, size, num_bboxes=num_bboxes, device="cpu") for idx in idxs]
    bboxes = torch.stack(bboxes, 0)
    return bboxes
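
For later reference: each row of the returned tensor holds [center x, center y, width, height, confidence, class idx], with the coordinates expressed as fractions of the (letterboxed) image and padding rows marked by a class index of -1. A quick check on one of the images used throughout this post:

bboxes = load_bboxes(8, size=320, num_bboxes=10)
print(bboxes.shape)  # torch.Size([10, 6]): 10 (possibly padded) rows of 6 values each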
In [3]:
batch_imgs = load_image_batch([8,16,33,60], size=320)
batch_labels = load_bboxes_batch([8,16,33,60], size=320, num_bboxes=10)
batch_predictions = network(batch_imgs)

print("labels:")
show_images_with_boxes(batch_imgs, batch_labels)
print("predictions:")
show_images_with_boxes(batch_imgs, nms(filter_boxes(batch_predictions, 0.2), 0.5))
labels:
predictions:

Visually, these labels correspond quite well to the predictions, but we still need a way to quantify how well they match.

YOLO Loss function

Quantifying how well each predicted bounding box fits the bounding box labels can be done by defining a custom loss function. In the case of YOLO, however, this is no small task. The loss function used to train YOLO from scratch is this monster (here adapted for YOLOv2):

$$ \begin{align*} \lambda_\textbf{coord}& \sum_{i = 0}^{M\times N} \sum_{j = 0}^{A} {{1}}_{ij}^{\text{obj}} \left[ \left( x_i - \hat{x}_i \right)^2 + \left( y_i - \hat{y}_i \right)^2 \right] \\ + \lambda_\textbf{coord}& \sum_{i = 0}^{M\times N} \sum_{j = 0}^{A} {{1}}_{ij}^{\text{obj}} \left[ \left( \sqrt{w_i} - \sqrt{\hat{w}_i} \right)^2 + \left( \sqrt{h_i} - \sqrt{\hat{h}_i} \right)^2 \right] \\ + &\sum_{i = 0}^{M\times N} \sum_{j = 0}^{A} {{1}}_{ij}^{\text{obj}} \left( C_i - \hat{C}_i \right)^2 \\ + \lambda_\textrm{noobj}& \sum_{i = 0}^{M\times N} \sum_{j = 0}^{A} {{1}}_{ij}^{\text{noobj}} \left( C_i - \hat{C}_i \right)^2 \\ + &\sum_{i = 0}^{M\times N} \sum_{j = 0}^{A} {{1}}_{ij}^{\text{obj}} \sum_{c \in \textrm{classes}} \left( p_i(c) - \hat{p}_i(c) \right)^2 \end{align*} $$

This loss function can be broken down as follows (a simplified code transcription of these terms is given right after the list).

  • The double summation $\sum_{i = 0}^{M\times N} \sum_{j = 0}^{A}$ that's present in every term of the loss runs over each of the predictions (# x-values $\times$ # y-values $\times$ # anchors).
  • The first term of this loss function describes the sum-squared error between the center coordinates of each predicted bounding box and those of its target box.
  • The second term describes the sum-squared error of the square roots of the widths/heights. The square root is there to (partially) reduce the importance of size deviations on large bounding boxes compared to small ones.
  • The third term describes the sum-squared error in the confidences for boxes that do contain objects.
  • The fourth term describes the sum-squared error in the confidences for boxes that don't contain objects.
  • Finally, the last term describes the sum-squared error on the classification vector, but only if an object is present in the bounding box.
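
To make the connection with the implementation below more explicit, here is a naive, hypothetical transcription of these five terms. The function name, the obj_mask/noobj_mask arguments and the λ defaults of 5.0 and 0.5 (taken from the original YOLO paper) are assumptions for illustration only; the real YOLOLoss below additionally handles anchor matching, a warm-up phase, per-box weighting, and uses a cross-entropy loss for the class term:

def yolo_loss_terms(pred, target, obj_mask, noobj_mask, lambda_coord=5.0, lambda_noobj=0.5):
    """Naive transcription of the five YOLO loss terms (illustration only).

    pred / target: (..., 5 + num_classes) tensors holding [x, y, w, h, C, p(c_0), p(c_1), ...];
    obj_mask / noobj_mask: broadcastable 0/1 masks marking which predictions are (not)
    responsible for an object.
    """
    x, y, w, h, C = (pred[..., i] for i in range(5))
    tx, ty, tw, th, tC = (target[..., i] for i in range(5))
    p, tp = pred[..., 5:], target[..., 5:]
    return (
        lambda_coord * (obj_mask * ((x - tx) ** 2 + (y - ty) ** 2)).sum()       # term 1: centers
        + lambda_coord * (obj_mask * ((w.sqrt() - tw.sqrt()) ** 2
                                      + (h.sqrt() - th.sqrt()) ** 2)).sum()     # term 2: sizes
        + (obj_mask * (C - tC) ** 2).sum()                                      # term 3: obj confidence
        + lambda_noobj * (noobj_mask * (C - tC) ** 2).sum()                     # term 4: no-obj confidence
        + (obj_mask * ((p - tp) ** 2).sum(-1)).sum()                            # term 5: classification
    )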
In [4]:
class YOLOLoss(torch.nn.modules.loss._Loss):
    """ A loss function to train YOLO v2

    Args:
        anchors (optional, list): the list of anchors (should be the same anchors as the ones defined in the YOLO class)
        seen (optional, torch.Tensor): the number of images the network has already been trained on
        coord_prefill (optional, int): the number of images for which the predicted bboxes will be centered in the image
        threshold (optional, float): minimum iou necessary to have a predicted bbox match a target bbox
        lambda_coord (optional, float): hyperparameter controlling the importance of the bbox coordinate predictions
        lambda_noobj (optional, float): hyperparameter controlling the importance of the bboxes containing no objects
        lambda_obj (optional, float): hyperparameter controlling the importance of the bboxes containing objects
        lambda_cls (optional, float): hyperparameter controlling the importance of the class prediction if the bbox contains an object
    """
    def __init__(
        self,
        anchors=(
            (1.08, 1.19),
            (3.42, 4.41),
            (6.63, 11.38),
            (9.42, 5.11),
            (16.62, 10.52),
        ),
        seen=0,
        coord_prefill=12800,
        threshold=0.6,
        lambda_coord=1.0,
        lambda_noobj=1.0,
        lambda_obj=5.0,
        lambda_cls=1.0,
    ):
        super().__init__()

        if not torch.is_tensor(anchors):
            anchors = torch.tensor(anchors, dtype=torch.get_default_dtype())
        else:
            anchors = anchors.data.to(torch.get_default_dtype())
        self.register_buffer("anchors", anchors)

        self.seen = int(seen+.5)
        self.coord_prefill = int(coord_prefill+.5)

        self.threshold = float(threshold)
        self.lambda_coord = float(lambda_coord)
        self.lambda_noobj = float(lambda_noobj)
        self.lambda_obj = float(lambda_obj)
        self.lambda_cls = float(lambda_cls)

        self.mse = torch.nn.MSELoss(reduction='sum')
        self.cel = torch.nn.CrossEntropyLoss(reduction='sum')

    def forward(self, x, y):
        nT = y.shape[1]
        nA = self.anchors.shape[0]
        nB, _, nH, nW = x.shape
        nPixels = nH * nW
        nAnchors = nA * nPixels
        y = y.to(dtype=x.dtype, device=x.device)
        x = x.view(nB, nA, -1, nH, nW).permute(0, 1, 3, 4, 2)
        nC = x.shape[-1] - 5
        self.seen += nB

        anchors = self.anchors.to(dtype=x.dtype, device=x.device)
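        # build per-prediction masks and targets: coord_mask / conf_mask / cls_mask decide which
        # predictions contribute to each loss term (and with what weight), while tcoord / tconf /
        # tcls hold the corresponding target values.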
        coord_mask = torch.zeros(nB, nA, nH, nW, 1, requires_grad=False, dtype=x.dtype, device=x.device)
        conf_mask = torch.ones(nB, nA, nH, nW, requires_grad=False, dtype=x.dtype, device=x.device) * self.lambda_noobj
        cls_mask = torch.zeros(nB, nA, nH, nW, requires_grad=False, dtype=torch.bool, device=x.device)
        tcoord = torch.zeros(nB, nA, nH, nW, 4, requires_grad=False, dtype=x.dtype, device=x.device)
        tconf = torch.zeros(nB, nA, nH, nW, requires_grad=False, dtype=x.dtype, device=x.device)
        tcls = torch.zeros(nB, nA, nH, nW, requires_grad=False, dtype=x.dtype, device=x.device)

        coord = torch.cat([
            x[:, :, :, :, 0:1].sigmoid(),  # X center
            x[:, :, :, :, 1:2].sigmoid(),  # Y center
            x[:, :, :, :, 2:3],  # Width
            x[:, :, :, :, 3:4],  # Height
        ], -1)

        range_y, range_x = torch.meshgrid(
            torch.arange(nH, dtype=x.dtype, device=x.device),
            torch.arange(nW, dtype=x.dtype, device=x.device),
        )
        anchor_x, anchor_y = anchors[:, 0], anchors[:, 1]

        x = torch.cat([
            (x[:, :, :, :, 0:1].sigmoid() + range_x[None,None,:,:,None]),  # X center
            (x[:, :, :, :, 1:2].sigmoid() + range_y[None,None,:,:,None]),  # Y center
            (x[:, :, :, :, 2:3].exp() * anchor_x[None,:,None,None,None]),  # Width
            (x[:, :, :, :, 3:4].exp() * anchor_y[None,:,None,None,None]),  # Height
            x[:, :, :, :, 4:5].sigmoid(), # confidence
            x[:, :, :, :, 5:], # classes (NOTE: no softmax here bc CEL is used later, which works on logits)
        ], -1)

        conf = x[..., 4]
        cls = x[..., 5:].reshape(-1, nC)
        x = x[..., :4].detach() # gradients are tracked in coord -> not here anymore.

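        # warm-up phase: for the first `coord_prefill` images, gently pull every predicted box
        # towards the center of its grid cell (target x = y = 0.5) with a small weight, which
        # stabilizes the coordinate predictions early in training.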
        if self.seen < self.coord_prefill:
            coord_mask.fill_(np.sqrt(.01 / self.lambda_coord))
            tcoord[..., 0].fill_(0.5)
            tcoord[..., 1].fill_(0.5)

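        # for every image in the batch: match each ground truth box to the best-fitting anchor
        # in its grid cell and fill in the corresponding masks and target values.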
        for b in range(nB):
            gt = y[b][(y[b, :, -1] >= 0)[:, None].expand_as(y[b])].view(-1, 6)[:,:4]
            gt[:, ::2] *= nW
            gt[:, 1::2] *= nH
            if gt.numel() == 0:  # no ground truth for this image
                continue

            # Set confidence mask of matching detections to 0
            iou_gt_pred = iou(gt, x[b:(b+1)].view(-1, 4))
            mask = (iou_gt_pred > self.threshold).sum(0) >= 1
            conf_mask[b][mask.view_as(conf_mask[b])] = 0

            # Find best anchor for each gt
            iou_gt_anchors = iou_wh(gt[:,2:], anchors)
            _, best_anchors = iou_gt_anchors.max(1)

            # Set masks and target values for each gt
            nGT = gt.shape[0]
            gi = gt[:, 0].clamp(0, nW-1).long()
            gj = gt[:, 1].clamp(0, nH-1).long()

            conf_mask[b, best_anchors, gj, gi] = self.lambda_obj
            tconf[b, best_anchors, gj, gi] = iou_gt_pred.view(nGT, nA, nH, nW)[torch.arange(nGT), best_anchors, gj, gi]
            coord_mask[b, best_anchors, gj, gi, :] = (2 - (gt[:, 2] * gt[:, 3]) / nPixels)[..., None]
            tcoord[b, best_anchors, gj, gi, 0] = gt[:, 0] - gi.float()
            tcoord[b, best_anchors, gj, gi, 1] = gt[:, 1] - gj.float()
            tcoord[b, best_anchors, gj, gi, 2] = (gt[:, 2] / anchors[best_anchors, 0]).log()
            tcoord[b, best_anchors, gj, gi, 3] = (gt[:, 3] / anchors[best_anchors, 1]).log()
            cls_mask[b, best_anchors, gj, gi] = 1
            tcls[b, best_anchors, gj, gi] = y[b, torch.arange(nGT), -1]

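        # the masks are multiplied into the MSE arguments below (and therefore get squared),
        # so take their square root here to make the weights act linearly on the squared errors.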
        coord_mask = coord_mask.sqrt()
        conf_mask = conf_mask.sqrt()
        tcls = tcls[cls_mask].view(-1).long()
        cls_mask = cls_mask.view(-1, 1).expand(nB*nA*nH*nW, nC)
        cls = cls[cls_mask].view(-1, nC)

        loss_coord = self.lambda_coord * self.mse(coord*coord_mask, tcoord*coord_mask) / (2 * nB)
        loss_conf = self.mse(conf*conf_mask, tconf*conf_mask) / (2 * nB)
        loss_cls = self.lambda_cls * self.cel(cls, tcls) / nB
        return loss_coord + loss_conf + loss_cls
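
The loss above relies on the iou and iou_wh helpers from the previous post. For readers landing here directly, a minimal sketch of what iou_wh computes (the pairwise IoU between boxes given only by width and height, as if they shared a common center, which is how ground truths are matched to anchors) could look like the following; the actual implementation in the previous post may differ in its details:

import torch

def iou_wh(wh1, wh2):
    """Pairwise IoU between (N, 2) and (M, 2) width/height tensors, assuming a common center."""
    w1, h1 = wh1[:, 0:1], wh1[:, 1:2]          # (N, 1)
    w2, h2 = wh2[None, :, 0], wh2[None, :, 1]  # (1, M)
    intersection = torch.min(w1, w2) * torch.min(h1, h2)
    union = w1 * h1 + w2 * h2 - intersection
    return intersection / union                # (N, M)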
In [5]:
lossfunc = YOLOLoss(anchors=network.anchors)
batch_output = network(batch_imgs, yolo=False)
lossfunc(batch_output, batch_labels)
Out[5]:
tensor(7.6441, grad_fn=<AddBackward0>)
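
Note that the loss operates on the raw convolutional output of the network (hence yolo=False), not on the decoded boxes: forward reshapes the (nB, nA·(5 + nC), nH, nW) feature map itself and applies the sigmoid/exp decoding internally. A quick sanity check of that shape; the exact numbers assume the 5 anchors and 20 VOC classes used throughout this series, a 320-pixel input and a total stride of 32:

print(batch_output.shape)  # should be torch.Size([4, 125, 10, 10]): 4 images, 5*(5+20) channels, a 10x10 grid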

Training

To check that the YOLO loss actually works, we'll first randomly re-initialize the last layer of the TinyYOLOv2 network and then retrain this layer on the same VOC dataset:

In [6]:
for p in network.conv9.parameters():
    try:
        torch.nn.init.kaiming_normal_(p)
    except ValueError:
        torch.nn.init.normal_(p)
batch_predictions = network(batch_imgs)
show_images_with_boxes(batch_imgs, batch_labels)
show_images_with_boxes(batch_imgs, nms(filter_boxes(batch_predictions, 0.2), 0.5))

Let's now do the training:

In [7]:
batch_size = 64
all_idxs = np.array([int(fn.split("2008_")[-1].split(".jpg")[0]) for fn in sorted(glob.glob("VOCdevkit/VOC2012/JPEGImages/2008_*"))], dtype=int)
lossfunc = YOLOLoss(anchors=network.anchors, coord_prefill=int(5*all_idxs.shape[0]))
optimizer = torch.optim.Adam(network.conv9.parameters(), lr=0.003)
In [8]:
%%time
device=torch.device("cuda")
network = network.to(device)
np.random.RandomState(seed=42).shuffle(all_idxs)
valid_idxs = all_idxs[-4*batch_size:]
train_idxs = all_idxs[:-4*batch_size]

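# train for 20 epochs; at the start of each epoch, the loss on a fixed held-out set of
# 4*batch_size images is computed once and shown alongside the running training loss.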
for e in range(20):
    np.random.shuffle(train_idxs)
    range_ = tqdm(np.array_split(train_idxs, batch_size))
    with torch.no_grad():
        valid_imgs = load_image_batch(valid_idxs, size=320).to(device)
        valid_labels = load_bboxes_batch(valid_idxs, size=320, num_bboxes=10)
        valid_predictions = network(valid_imgs, yolo=False)
        valid_loss = lossfunc(valid_predictions, valid_labels).item()
        range_.set_postfix(valid_loss=valid_loss)
    for i, idxs in enumerate(range_):
        optimizer.zero_grad()
        batch_imgs = load_image_batch(idxs, size=320).to(device)
        batch_labels = load_bboxes_batch(idxs, size=320, num_bboxes=10)
        batch_predictions = network(batch_imgs, yolo=False)
        loss = lossfunc(batch_predictions, batch_labels)
        range_.set_postfix(loss=loss.item(), valid_loss=valid_loss)
        loss.backward()
        optimizer.step()
CPU times: user 4h 43min 52s, sys: 8min 6s, total: 4h 51min 59s
Wall time: 19min 17s
In [9]:
network = network.to("cpu")
imgs = load_image_batch([8, 16, 33, 60], size=320)
output_tensor = network(imgs)
filtered_tensor = filter_boxes(output_tensor, 0.3)
nms_tensor = nms(filtered_tensor, 0.3)
show_images_with_boxes(imgs, nms_tensor)


Credits

This part would not have been possible without the excellent lightnet implementation to guide me. Please check it out for a more complete PyTorch re-implementation of the YOLO architectures.



If you like this post, consider leaving a comment or starring it on GitHub.