# Towards a loss function for YOLO

posted on 2020-09-01T09:57:11Z

Writing the second part of the YOLO series took a lot longer than I care to admit... but it's finally here! In this part we'll work towards the definition of a loss function to train the YOLO architecture.


In the previous post, we saw how to make bounding box predictions with YOLO. But how do they compare with the ground truth boxes? For this we need a way to load the ground truths:

In [2]:

```python
from PIL import Image
import torch

# CLASS2NUM (VOC class name -> class index) was defined earlier in the series.

def load_bboxes(idx, size, num_bboxes=None, device="cpu"):
    boxfilename = f"VOCdevkit/VOC2012/Annotations/2008_{str(idx).zfill(6)}.xml"
    imgfilename = f"VOCdevkit/VOC2012/JPEGImages/2008_{str(idx).zfill(6)}.jpg"
    img = Image.open(imgfilename)  # img won't be loaded into memory, just needed for metadata.
    try:
        width, height = size
    except TypeError:
        width = height = size
    # Letterbox geometry: scale to fit, then pad the shorter side.
    scale = min(width / img.width, height / img.height)
    new_width, new_height = int(img.width * scale), int(img.height * scale)
    # Total padding, expressed in original-image pixels.
    diff_width = (width - new_width) / width * img.width
    diff_height = (height - new_height) / height * img.height
    bboxes = []
    with open(boxfilename, "r") as file:
        for line in file:
            if "<name>" in line:
                class_ = line.split("<name>")[-1].split("</name>")[0].strip()
            elif "<xmin>" in line:
                x0 = float(line.split("<xmin>")[-1].split("</xmin>")[0].strip())
            elif "<xmax>" in line:
                x1 = float(line.split("<xmax>")[-1].split("</xmax>")[0].strip())
            elif "<ymin>" in line:
                y0 = float(line.split("<ymin>")[-1].split("</ymin>")[0].strip())
            elif "<ymax>" in line:
                y1 = float(line.split("<ymax>")[-1].split("</ymax>")[0].strip())
            elif "</object>" in line:
                if class_ not in CLASS2NUM:
                    continue
                bbox = [
                    (diff_width/2 + (x0+x1)/2) / (img.width + diff_width),    # center x
                    (diff_height/2 + (y0+y1)/2) / (img.height + diff_height), # center y
                    (max(x0, x1) - min(x0, x1)) / (img.width + diff_width),   # width
                    (max(y0, y1) - min(y0, y1)) / (img.height + diff_height), # height
                    1.0,                                                      # confidence
                    CLASS2NUM[class_],                                        # class idx
                ]
                # Store (absolute area, bbox) so the boxes get sorted by area below.
                bboxes.append((bbox[2]*bbox[3]*(img.width+diff_width)*(img.height+diff_height), bbox))

    bboxes = torch.tensor([bbox for _, bbox in sorted(bboxes)], dtype=torch.get_default_dtype(), device=device)
    bboxes = bboxes.reshape(-1, 6)  # keep a (0, 6) shape when the image has no boxes.
    if num_bboxes:
        if num_bboxes > len(bboxes):
            # Pad with all-zero boxes; class idx -1 marks them as padding.
            zeros = torch.zeros((num_bboxes - bboxes.shape[0], 6), dtype=torch.get_default_dtype(), device=device)
            zeros[:, -1] = -1
            bboxes = torch.cat([bboxes, zeros], 0)
        elif num_bboxes < len(bboxes):
            bboxes = bboxes[:num_bboxes]

    return bboxes

def load_bboxes_batch(idxs, size, num_bboxes=None):
    # (Reconstructed signature: the original def line was lost in extraction.)
    bboxes = [load_bboxes(idx, size, num_bboxes=num_bboxes, device="cpu") for idx in idxs]
    bboxes = torch.stack(bboxes, 0)
    return bboxes
```
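To see what the normalisation above does, here is a small worked example with made-up numbers (not from VOC): a 500×375 image letterboxed into a 320×320 square, and a box spanning (100, 100)–(200, 200). Every coordinate ends up in [0, 1] relative to the padded image:

```python
# Hypothetical numbers: a 500x375 image letterboxed into a 320x320 square.
img_w, img_h = 500, 375
width = height = 320

scale = min(width / img_w, height / img_h)             # 320/500 = 0.64
new_w, new_h = int(img_w * scale), int(img_h * scale)  # 320 x 240
# Padding expressed back in original-image pixels, as in load_bboxes.
diff_w = (width - new_w) / width * img_w               # 0 (width fits exactly)
diff_h = (height - new_h) / height * img_h             # 80/320 * 375 = 93.75

# A box from (100, 100) to (200, 200) in the original image:
x0, y0, x1, y1 = 100, 100, 200, 200
cx = (diff_w / 2 + (x0 + x1) / 2) / (img_w + diff_w)   # 150 / 500    = 0.3
cy = (diff_h / 2 + (y0 + y1) / 2) / (img_h + diff_h)   # 196.875 / 468.75 = 0.42
w = (x1 - x0) / (img_w + diff_w)                        # 100 / 500    = 0.2
h = (y1 - y0) / (img_h + diff_h)                        # 100 / 468.75 ≈ 0.2133

print(cx, cy, w, h)
```

Note how the vertical padding both shifts the center y (via `diff_h / 2`) and shrinks the normalised height (via the larger denominator), so the box stays aligned with the letterboxed image.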

In [3]:

```python
batch_imgs = load_image_batch([8, 16, 33, 60], size=320)
```
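Stacking boxes into a batch tensor only works because every image contributes exactly `num_bboxes` rows; images with fewer objects are padded with all-zero boxes whose class index is -1. A quick illustration of that padding scheme, independent of the VOC files (the box values are made up):

```python
import torch

num_bboxes = 3
# Two hypothetical boxes in (cx, cy, w, h, confidence, class idx) format:
bboxes = torch.tensor([[0.3, 0.42, 0.2, 0.21, 1.0, 11.0],
                       [0.5, 0.5, 0.1, 0.1, 1.0, 4.0]])
# Pad up to num_bboxes rows, exactly as load_bboxes does:
zeros = torch.zeros((num_bboxes - bboxes.shape[0], 6))
zeros[:, -1] = -1  # class idx -1 marks a padding row
bboxes = torch.cat([bboxes, zeros], 0)

print(bboxes.shape)  # torch.Size([3, 6])
print(bboxes[-1])    # all zeros except the -1 class idx
```

The -1 sentinel lets the loss function later mask out padding rows without a separate validity tensor.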

labels: *(figure: the batch images with their ground-truth boxes drawn on top)*

predictions: *(figure: the same images with the predicted boxes drawn on top)*
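The usual way to quantify how well a predicted box matches a ground-truth box is the intersection-over-union (IoU). As a minimal sketch (a plain-Python helper, not the vectorised version a training loop would use), for boxes in the same (center x, center y, width, height) format as above:

```python
def iou(box1, box2):
    """IoU of two boxes given as (center x, center y, width, height)."""
    # Convert centers/sizes to corner coordinates.
    x0a, y0a = box1[0] - box1[2] / 2, box1[1] - box1[3] / 2
    x1a, y1a = box1[0] + box1[2] / 2, box1[1] + box1[3] / 2
    x0b, y0b = box2[0] - box2[2] / 2, box2[1] - box2[3] / 2
    x1b, y1b = box2[0] + box2[2] / 2, box2[1] + box2[3] / 2
    # Intersection rectangle; clamp at zero when the boxes don't overlap.
    iw = max(0.0, min(x1a, x1b) - max(x0a, x0b))
    ih = max(0.0, min(y1a, y1b) - max(y0a, y0b))
    inter = iw * ih
    union = box1[2] * box1[3] + box2[2] * box2[3] - inter
    return inter / union

print(iou((0.5, 0.5, 0.2, 0.2), (0.5, 0.5, 0.2, 0.2)))  # identical boxes -> ~1.0
print(iou((0.5, 0.5, 0.2, 0.2), (0.9, 0.9, 0.2, 0.2)))  # disjoint boxes -> 0.0
```

IoU ranges from 0 (no overlap) to 1 (perfect overlap), which is what makes it a natural ingredient when scoring predictions against ground truths.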