Towards a loss function for YOLO
posted on 2020-09-01T09:57:11Z

Writing the second part of the YOLO series took a lot longer than I care to admit... but it's finally here! In this part, we'll work towards a loss function for training the YOLO architecture.
Index
- part 1: introduction to YOLO and TinyYOLOv2 implementation
- part 2 (this post): towards a YOLO loss function
- part 3 (in progress): transfer learning with YOLO
In the previous post, we saw how to make bounding box predictions with YOLO. But how do those predictions compare with the ground truth boxes? To answer that, we first need a way to load the ground truths:
In [2]:
def load_bboxes(idx, size, num_bboxes=None, device="cpu"):
    boxfilename = f"VOCdevkit/VOC2012/Annotations/2008_{str(idx).zfill(6)}.xml"
    imgfilename = f"VOCdevkit/VOC2012/JPEGImages/2008_{str(idx).zfill(6)}.jpg"
    img = Image.open(imgfilename)  # img won't be loaded into memory, just needed for metadata.
    try:
        width, height = size
    except TypeError:
        width = height = size
    scale = min(width / img.width, height / img.height)
    new_width, new_height = int(img.width * scale), int(img.height * scale)
    diff_width = (width - new_width) / width * img.width
    diff_height = (height - new_height) / height * img.height
    bboxes = []
    with open(boxfilename, "r") as file:
        for line in file:
            if "<name>" in line:
                class_ = line.split("<name>")[-1].split("</name>")[0].strip()
            elif "<xmin>" in line:
                x0 = float(line.split("<xmin>")[-1].split("</xmin>")[0].strip())
            elif "<xmax>" in line:
                x1 = float(line.split("<xmax>")[-1].split("</xmax>")[0].strip())
            elif "<ymin>" in line:
                y0 = float(line.split("<ymin>")[-1].split("</ymin>")[0].strip())
            elif "<ymax>" in line:
                y1 = float(line.split("<ymax>")[-1].split("</ymax>")[0].strip())
            elif "</object>" in line:
                if class_ not in CLASS2NUM:
                    continue
                bbox = [
                    (diff_width/2 + (x0+x1)/2) / (img.width + diff_width),    # center x
                    (diff_height/2 + (y0+y1)/2) / (img.height + diff_height), # center y
                    (max(x0,x1) - min(x0,x1)) / (img.width + diff_width),     # width
                    (max(y0,y1) - min(y0,y1)) / (img.height + diff_height),   # height
                    1.0,                # confidence
                    CLASS2NUM[class_],  # class idx
                ]
                bboxes.append((bbox[2]*bbox[3]*(img.width+diff_width)*(img.height+diff_height), bbox))
    bboxes = torch.tensor([bbox for _, bbox in sorted(bboxes)], dtype=torch.get_default_dtype(), device=device)
    if num_bboxes:
        if num_bboxes > len(bboxes):
            zeros = torch.zeros((num_bboxes - bboxes.shape[0], 6), dtype=torch.get_default_dtype(), device=device)
            zeros[:, -1] = -1  # mark padding rows with class index -1
            bboxes = torch.cat([bboxes, zeros], 0)
        elif num_bboxes < len(bboxes):
            bboxes = bboxes[:num_bboxes]
    return bboxes

def load_bboxes_batch(idxs, size, num_bboxes, device="cpu"):
    bboxes = [load_bboxes(idx, size, num_bboxes=num_bboxes, device=device) for idx in idxs]
    bboxes = torch.stack(bboxes, 0)
    return bboxes
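Since different images contain different numbers of objects, batching requires a fixed num_bboxes per image: shorter annotation lists are padded with all-zero rows whose class index is set to -1, so the padding can be masked out later. Here is that padding step in isolation, with hypothetical box values:

```python
import torch

# two real boxes: (center x, center y, width, height, confidence, class idx)
bboxes = torch.tensor([
    [0.50, 0.50, 0.20, 0.30, 1.0, 11.0],
    [0.25, 0.40, 0.10, 0.10, 1.0,  7.0],
])
num_bboxes = 4

# same padding as in load_bboxes: zero rows, class index set to -1
zeros = torch.zeros((num_bboxes - bboxes.shape[0], 6))
zeros[:, -1] = -1
padded = torch.cat([bboxes, zeros], 0)

print(padded.shape)            # torch.Size([4, 6])
print(padded[:, -1].tolist())  # [11.0, 7.0, -1.0, -1.0]
```

The -1 sentinel lives in the class column rather than the confidence column, so a simple `labels[..., -1] >= 0` mask recovers the real boxes.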
In [3]:
batch_imgs = load_image_batch([8,16,33,60], size=320)
batch_labels = load_bboxes_batch([8,16,33,60], size=320, num_bboxes=10)
batch_predictions = network(batch_imgs)
print("labels:")
show_images_with_boxes(batch_imgs, batch_labels)
print("predictions:")
show_images_with_boxes(batch_imgs, nms(filter_boxes(batch_predictions, 0.2), 0.5))
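To turn "how do the predictions compare with the ground truths?" into something a loss can optimize, we'll need to score the overlap between a predicted box and a ground truth box. The standard measure is intersection over union (IoU). As a preview, here is a minimal single-box sketch for the (center x, center y, width, height) format produced by load_bboxes; this is an illustrative scalar version, not the batched tensor implementation a training loop would use:

```python
def iou(box1, box2):
    """IoU of two boxes given as (cx, cy, w, h) in normalized coordinates."""
    # convert center/size to corner coordinates
    x0a, y0a = box1[0] - box1[2] / 2, box1[1] - box1[3] / 2
    x1a, y1a = box1[0] + box1[2] / 2, box1[1] + box1[3] / 2
    x0b, y0b = box2[0] - box2[2] / 2, box2[1] - box2[3] / 2
    x1b, y1b = box2[0] + box2[2] / 2, box2[1] + box2[3] / 2
    # intersection rectangle (zero-sized if the boxes don't overlap)
    iw = max(0.0, min(x1a, x1b) - max(x0a, x0b))
    ih = max(0.0, min(y1a, y1b) - max(y0a, y0b))
    inter = iw * ih
    union = box1[2] * box1[3] + box2[2] * box2[3] - inter
    return inter / union if union > 0 else 0.0

print(iou((0.5, 0.5, 0.2, 0.2), (0.5, 0.5, 0.2, 0.2)))  # 1.0 for identical boxes
print(iou((0.1, 0.1, 0.1, 0.1), (0.9, 0.9, 0.1, 0.1)))  # 0.0 for disjoint boxes
```

IoU is 1 for a perfect match, 0 for no overlap, and degrades smoothly in between, which is exactly what we want when deciding which predicted box should be responsible for which ground truth.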