  toggle dark mode toggle dark mode

Towards a loss function for YOLO

posted on 2020-09-01T09:57:11Z · view page on GitHub
tags: ml / machine learning (7), python (6), cv / computer vision (3), yolo (2)

Writing the second part of the YOLO series took a lot longer than I care to admit, but it's finally here! In this part we'll go over the definition of a loss function to train the YOLO architecture.


In the previous post, we saw how to make bounding box predictions with YOLO. But how do they compare with the ground truth boxes? For this we need a way to load the ground truths:

In [2]:
def load_bboxes(idx, size, num_bboxes=None, device="cpu"):
    """Load the ground-truth bounding boxes of one VOC2012 (2008_*) image.

    Args:
        idx: numeric image index; zero-padded to six digits to build the
            annotation (.xml) and image (.jpg) file names.
        size: target size, either a ``(width, height)`` pair or a single
            number used for both dimensions.
        num_bboxes: if truthy, the result is padded or truncated to exactly
            this many rows. Padding rows are all zeros except the last
            column (class idx), which is set to -1.
        device: torch device on which the returned tensor is created.

    Returns:
        A tensor with one row per box and six columns:
        ``[center x, center y, width, height, confidence, class idx]``.
        The geometric columns are divided by the padded image dimensions.
    """
    boxfilename = f"VOCdevkit/VOC2012/Annotations/2008_{str(idx).zfill(6)}.xml"
    imgfilename = f"VOCdevkit/VOC2012/JPEGImages/2008_{str(idx).zfill(6)}.jpg"

    # Only the image dimensions are needed; close the file handle right away.
    with Image.open(imgfilename) as img:
        img_width, img_height = img.width, img.height

    try:
        width, height = size
    except TypeError:  # size is a single number, not a pair
        width = height = size

    scale = min(width / img_width, height / img_height)
    new_width, new_height = int(img_width * scale), int(img_height * scale)
    # NOTE(review): the padding is computed as a fraction of the target size
    # and then multiplied by the original dimension; confirm that this matches
    # the letterboxing applied by the image loader.
    diff_width = (width - new_width) / width * img_width
    diff_height = (height - new_height) / height * img_height
    padded_width = img_width + diff_width
    padded_height = img_height + diff_height

    def tag_value(line, tag):
        """Return the stripped text between <tag> and </tag> on this line."""
        return line.split(f"<{tag}>")[-1].split(f"</{tag}>")[0].strip()

    bboxes = []
    with open(boxfilename, "r") as file:
        for line in file:
            if "<name>" in line:
                class_ = tag_value(line, "name")
            elif "<xmin>" in line:
                x0 = float(tag_value(line, "xmin"))
            elif "<xmax>" in line:
                x1 = float(tag_value(line, "xmax"))
            elif "<ymin>" in line:
                y0 = float(tag_value(line, "ymin"))
            elif "<ymax>" in line:
                y1 = float(tag_value(line, "ymax"))
            elif "</object>" in line:
                if class_ not in CLASS2NUM:
                    continue  # skip objects whose class is not in CLASS2NUM
                bbox = [
                    (diff_width / 2 + (x0 + x1) / 2) / padded_width,  # center x
                    (diff_height / 2 + (y0 + y1) / 2) / padded_height,  # center y
                    (max(x0, x1) - min(x0, x1)) / padded_width,  # width
                    (max(y0, y1) - min(y0, y1)) / padded_height,  # height
                    1.0,  # confidence
                    CLASS2NUM[class_],  # class idx
                ]
                # NOTE(review): the append statement was missing from the
                # listing. The (key, bbox) tuple shape is required by the
                # sorted() unpacking below; using the box area as the sort key
                # is a reconstruction -- TODO confirm against the original.
                bboxes.append((bbox[2] * bbox[3], bbox))

    bboxes = torch.tensor(
        [bbox for _, bbox in sorted(bboxes)],
        dtype=torch.get_default_dtype(),
        device=device,
    ).reshape(-1, 6)  # keeps the (N, 6) layout even when no box was found

    if num_bboxes:
        if num_bboxes > len(bboxes):
            zeros = torch.zeros(
                (num_bboxes - bboxes.shape[0], 6),
                dtype=torch.get_default_dtype(),
                device=device,
            )
            zeros[:, -1] = -1  # mark padding rows with class idx -1
            bboxes = torch.cat([bboxes, zeros], 0)
        elif num_bboxes < len(bboxes):
            bboxes = bboxes[:num_bboxes]

    return bboxes

def load_bboxes_batch(idxs, size, num_bboxes, device="cpu"):
    """Load the ground-truth boxes of several images into one stacked tensor.

    Args:
        idxs: iterable of image indices, each passed to ``load_bboxes``.
        size: target size forwarded to ``load_bboxes``.
        num_bboxes: number of rows per image forwarded to ``load_bboxes``;
            every image must yield the same number of rows for the stack
            to succeed.
        device: torch device forwarded to ``load_bboxes`` (previously this
            argument was ignored and "cpu" was always used).

    Returns:
        The per-image tensors stacked along a new leading dimension.
    """
    per_image = [
        load_bboxes(idx, size, num_bboxes=num_bboxes, device=device)
        for idx in idxs
    ]
    return torch.stack(per_image, 0)
In [3]:
# Use one index list and one target size for both images and labels, so the
# boxes are computed for the same target size as the images they are drawn on
# (the labels were previously loaded with size=330 instead of 320).
sample_idxs = [8, 16, 33, 60]
input_size = 320

batch_imgs = load_image_batch(sample_idxs, size=input_size)
batch_labels = load_bboxes_batch(sample_idxs, size=input_size, num_bboxes=10)
batch_predictions = network(batch_imgs)

# First the ground-truth labels, then the network output after
# filter_boxes and nms.
show_images_with_boxes(batch_imgs, batch_labels)
show_images_with_boxes(batch_imgs, nms(filter_boxes(batch_predictions, 0.2), 0.5))