In [1]:
import torch
BATCH_SIZE = 8 # Increase / decrease according to GPU memory.
RESIZE_TO = 640 # Resize images to this size for training and transforms.
NUM_EPOCHS = 30 # Number of epochs to train for.
NUM_WORKERS = 4 # Number of parallel workers for data loading.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Training images and labels files directory.
TRAIN_DIR = "/content/drive/MyDrive/Colab Notebooks/data/train"
# Validation images and labels files directory.
VALID_DIR = "/content/drive/MyDrive/Colab Notebooks/data/valid"
# Classes: 0 index is reserved for background.
CLASSES = ["__background__", "buffalo", "elephant", "rhino", "zebra"]
NUM_CLASSES = len(CLASSES)
# Whether to visualize images after creating the data loaders.
VISUALIZE_TRANSFORMED_IMAGES = True
# Location to save model and plots.
OUT_DIR = "outputs"
In [2]:
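# NOTE: the PyPI "config" package installed here appears unrelated to the local
# config.py that the commented-out imports in later cells refer to; all constants
# are defined inline in this notebook, so this install is likely unnecessary.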
!pip install config
Collecting config Downloading config-0.5.1-py2.py3-none-any.whl.metadata (1.4 kB) Downloading config-0.5.1-py2.py3-none-any.whl (20 kB) Installing collected packages: config Successfully installed config-0.5.1
In [3]:
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

def collate_fn(batch):
"""
To handle the data loading as different images may have different
numbers of objects, and to handle varying-size tensors as well.
"""
return tuple(zip(*batch))
def get_train_transform():
# We keep "pascal_voc" because bounding box format is [x_min, y_min, x_max, y_max].
return A.Compose(
[
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.Rotate(limit=45),
A.Blur(blur_limit=3, p=0.2),
A.MotionBlur(blur_limit=3, p=0.1),
A.MedianBlur(blur_limit=3, p=0.1),
A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.3),
A.RandomScale(scale_limit=0.2, p=0.3),
ToTensorV2(p=1.0),
],
bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
)
def get_valid_transform():
return A.Compose(
[
ToTensorV2(p=1.0),
],
bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
)
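To make the role of `collate_fn` concrete, here is a minimal sketch with hypothetical toy entries: it regroups a batch of (image, target) pairs into a tuple of images and a tuple of targets, which is the structure the torchvision detection API expects.
In [ ]:
# Hypothetical toy batch: two (image, target) pairs with different object counts.
toy_batch = [
    ("image_0", {"labels": [1]}),
    ("image_1", {"labels": [2, 3]}),
]
images, targets = collate_fn(toy_batch)  # uses the function defined above
print(images)   # ('image_0', 'image_1')
print(targets)  # ({'labels': [1]}, {'labels': [2, 3]})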
In [4]:
import torch
import cv2
import numpy as np
import os
import glob
#from config import CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE
from torch.utils.data import Dataset, DataLoader
#from custom_utils import collate_fn, get_train_transform, get_valid_transform
class CustomDataset(Dataset):
def __init__(self, dir_path, width, height, classes, transforms=None):
"""
:param dir_path: Directory containing 'images/' and 'labels/' subfolders.
:param width: Resized image width.
:param height: Resized image height.
:param classes: List of class names (or an indexing scheme).
:param transforms: Albumentations transformations to apply.
"""
self.transforms = transforms
self.dir_path = dir_path
self.image_dir = os.path.join(self.dir_path, "images")
self.label_dir = os.path.join(self.dir_path, "labels")
self.width = width
self.height = height
self.classes = classes
# Gather all image paths
self.image_file_types = ["*.jpg", "*.jpeg", "*.png", "*.ppm", "*.JPG"]
self.all_image_paths = []
for file_type in self.image_file_types:
self.all_image_paths.extend(glob.glob(os.path.join(self.image_dir, file_type)))
# Sort for consistent ordering
self.all_image_paths = sorted(self.all_image_paths)
self.all_image_names = [os.path.basename(img_p) for img_p in self.all_image_paths]
def __len__(self):
return len(self.all_image_paths)
def __getitem__(self, idx):
# 1) Read image
image_name = self.all_image_names[idx]
image_path = os.path.join(self.image_dir, image_name)
label_filename = os.path.splitext(image_name)[0] + ".txt"
label_path = os.path.join(self.label_dir, label_filename)
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
# 2) Resize image (to the model's expected size)
image_resized = cv2.resize(image, (self.width, self.height))
image_resized /= 255.0 # Scale pixel values to [0, 1]
# 3) Read bounding boxes (normalized) from .txt file
boxes = []
labels = []
if os.path.exists(label_path):
with open(label_path, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
if not line:
continue
# Format: class_id x_min y_min x_max y_max (all in [0..1])
parts = line.split()
class_id = int(parts[0]) # e.g. 0, 1, 2, ...
xmin = float(parts[1])
ymin = float(parts[2])
xmax = float(parts[3])
ymax = float(parts[4])
# Example: if you want class IDs to start at 1 for foreground
# and background=0, do:
label_idx = class_id + 1
# Convert normalized coords to absolute (in resized space)
x_min_final = xmin * self.width
y_min_final = ymin * self.height
x_max_final = xmax * self.width
y_max_final = ymax * self.height
# Ensure valid box
if x_max_final <= x_min_final:
x_max_final = x_min_final + 1
if y_max_final <= y_min_final:
y_max_final = y_min_final + 1
# Clip if out of bounds
x_min_final = max(0, min(x_min_final, self.width - 1))
x_max_final = max(0, min(x_max_final, self.width))
y_min_final = max(0, min(y_min_final, self.height - 1))
y_max_final = max(0, min(y_max_final, self.height))
boxes.append([x_min_final, y_min_final, x_max_final, y_max_final])
labels.append(label_idx)
# 4) Convert boxes & labels to Torch tensors
if len(boxes) == 0:
boxes = torch.zeros((0, 4), dtype=torch.float32)
labels = torch.zeros((0,), dtype=torch.int64)
else:
boxes = torch.tensor(boxes, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.int64)
# 5) Prepare the target dict
area = (
(boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
if len(boxes) > 0
else torch.tensor([], dtype=torch.float32)
)
iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
image_id = torch.tensor([idx])
target = {"boxes": boxes, "labels": labels, "area": area, "iscrowd": iscrowd, "image_id": image_id}
# 6) Albumentations transforms: pass Python lists, not Tensors
if self.transforms:
bboxes_list = boxes.cpu().numpy().tolist() # shape: list of [xmin, ymin, xmax, ymax]
labels_list = labels.cpu().numpy().tolist() # shape: list of ints
transformed = self.transforms(
image=image_resized,
bboxes=bboxes_list,
labels=labels_list,
)
# Reassign the image
image_resized = transformed["image"]
# Convert bboxes back to Torch Tensors
new_bboxes_list = transformed["bboxes"] # list of [xmin, ymin, xmax, ymax]
new_labels_list = transformed["labels"] # list of int
if len(new_bboxes_list) > 0:
new_bboxes = torch.tensor(new_bboxes_list, dtype=torch.float32)
new_labels = torch.tensor(new_labels_list, dtype=torch.int64)
else:
new_bboxes = torch.zeros((0, 4), dtype=torch.float32)
new_labels = torch.zeros((0,), dtype=torch.int64)
target["boxes"] = new_bboxes
target["labels"] = new_labels
return image_resized, target
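As a sanity check on the label format the dataset expects (one object per line, `class_id x_min y_min x_max y_max` with normalized corner coordinates), the sketch below converts a hypothetical label line to absolute pixel coordinates at RESIZE_TO = 640, mirroring the arithmetic in `__getitem__`.
In [ ]:
# Hypothetical label line: file class ID 2 becomes label 3 ("rhino") after the +1 background shift.
line = "2 0.10 0.25 0.45 0.80"
class_id, xmin, ymin, xmax, ymax = line.split()
label_idx = int(class_id) + 1          # 0-based file IDs -> 1-based labels (0 = background)
w = h = 640                            # RESIZE_TO
box = [float(xmin) * w, float(ymin) * h, float(xmax) * w, float(ymax) * h]
print(label_idx, box)                  # 3 [64.0, 160.0, 288.0, 512.0]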
In [5]:
# ---------------------------------------------------------
# Debug/demo if run directly
# ---------------------------------------------------------
if __name__ == "__main__":
# Example usage with no transforms for debugging
dataset = CustomDataset(dir_path=TRAIN_DIR, width=RESIZE_TO, height=RESIZE_TO, classes=CLASSES, transforms=None)
print(f"Number of training images: {len(dataset)}")
if len(dataset) > 0:
from google.colab.patches import cv2_imshow # Import cv2_imshow
def visualize_sample(image, target):
"""
Visualize a single sample using OpenCV. Expects
`image` as a NumPy array of shape (H, W, 3) in [0..1].
"""
# Convert [0,1] float -> [0,255] uint8
img = (image * 255).astype(np.uint8)
# Convert RGB -> BGR
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
boxes = target["boxes"].cpu().numpy().astype(np.int32)
labels = target["labels"].cpu().numpy().astype(np.int32)
for i, box in enumerate(boxes):
x1, y1, x2, y2 = box
class_idx = labels[i]
# If your class_idx starts at 1 for "first class", ensure you handle that:
# e.g. if CLASSES = ["background", "class1", "class2", ...]
if 0 <= class_idx < len(CLASSES):
class_str = CLASSES[class_idx]
else:
class_str = f"Label_{class_idx}"
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
cv2.putText(
img, class_str, (x1, max(y1 - 5, 0)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2
)
cv2_imshow(img) # Use cv2_imshow instead of cv2.imshow
cv2.waitKey(0)
# Visualize a few samples
NUM_SAMPLES_TO_VISUALIZE = 10
for i in range(min(NUM_SAMPLES_TO_VISUALIZE, len(dataset))): # Also adjust loop range
image, target = dataset[i] # No transforms in this example
# `image` is shape (H, W, 3) in [0..1]
print(f"Visualizing sample {i}, boxes: {target['boxes'].shape[0]}")
visualize_sample(image, target)
cv2.destroyAllWindows()
else:
print("Dataset is empty. Cannot visualize samples.")
Output hidden; open in https://colab.research.google.com to view.
In [ ]:
def create_train_loader(train_dataset, num_workers=0):
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=num_workers,
collate_fn=collate_fn,
drop_last=True,
)
return train_loader
def create_valid_loader(valid_dataset, num_workers=0):
valid_loader = DataLoader(
valid_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_fn,
drop_last=True,
)
return valid_loader
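A quick usage sketch (hypothetical, assuming the config and dataset cells above have run and TRAIN_DIR contains data) to confirm that one collated batch has the expected structure.
In [ ]:
# Build a training dataset/loader and inspect one collated batch.
train_dataset = CustomDataset(TRAIN_DIR, RESIZE_TO, RESIZE_TO, CLASSES, get_train_transform())
train_loader = create_train_loader(train_dataset, num_workers=NUM_WORKERS)
images, targets = next(iter(train_loader))
print(len(images), images[0].shape)   # BATCH_SIZE, roughly torch.Size([3, 640, 640]) (RandomScale may alter size)
print(targets[0]["boxes"].shape)      # (number of objects in the first image, 4)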
In [6]:
!pip install torchmetrics
Collecting torchmetrics Downloading torchmetrics-1.8.1-py3-none-any.whl.metadata (22 kB) Collecting lightning-utilities>=0.8.0 (from torchmetrics) Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB) Downloading torchmetrics-1.8.1-py3-none-any.whl (982 kB) Downloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB) Installing collected packages: lightning-utilities, torchmetrics Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.1
In [7]:
from tqdm.auto import tqdm
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import torch
import matplotlib.pyplot as plt
import time
import os
In [8]:
plt.style.use("ggplot")
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
In [9]:
# Function for running training iterations.
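# NOTE: this version of train() reads `optimizer`, `train_loss_hist`, and DEVICE from
# module globals (created later in the notebook); a fully parameterized version is
# defined in the next cell and is the one called by the training loop.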
def train(train_data_loader, model):
print("Training")
model.train()
# initialize tqdm progress bar
prog_bar = tqdm(train_data_loader, total=len(train_data_loader))
for i, data in enumerate(prog_bar):
optimizer.zero_grad()
images, targets = data
images = list(image.to(DEVICE) for image in images)
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
loss_value = losses.item()
train_loss_hist.send(loss_value)
losses.backward()
optimizer.step()
# update the loss value beside the progress bar for each iteration
prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
return loss_value
In [10]:
from tqdm.auto import tqdm
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import torch
import matplotlib.pyplot as plt
import time
import os
import torchvision
from functools import partial
from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
import cv2
import numpy as np
import glob
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from google.colab.patches import cv2_imshow # Import cv2_imshow
# Define functions here
def collate_fn(batch):
"""
To handle the data loading as different images may have different
numbers of objects, and to handle varying-size tensors as well.
"""
return tuple(zip(*batch))
def get_train_transform():
# We keep "pascal_voc" because bounding box format is [x_min, y_min, x_max, y_max].
return A.Compose(
[
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.Rotate(limit=45),
A.Blur(blur_limit=3, p=0.2),
A.MotionBlur(blur_limit=3, p=0.1),
A.MedianBlur(blur_limit=3, p=0.1),
A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.3),
A.RandomScale(scale_limit=0.2, p=0.3),
ToTensorV2(p=1.0),
],
bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
)
def get_valid_transform():
return A.Compose(
[
ToTensorV2(p=1.0),
],
bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
)
class CustomDataset(Dataset):
def __init__(self, dir_path, width, height, classes, transforms=None):
"""
:param dir_path: Directory containing 'images/' and 'labels/' subfolders.
:param width: Resized image width.
:param height: Resized image height.
:param classes: List of class names (or an indexing scheme).
:param transforms: Albumentations transformations to apply.
"""
self.transforms = transforms
self.dir_path = dir_path
self.image_dir = os.path.join(self.dir_path, "images")
self.label_dir = os.path.join(self.dir_path, "labels")
self.width = width
self.height = height
self.classes = classes
# Gather all image paths
self.image_file_types = ["*.jpg", "*.jpeg", "*.png", "*.ppm", "*.JPG"]
self.all_image_paths = []
for file_type in self.image_file_types:
self.all_image_paths.extend(glob.glob(os.path.join(self.image_dir, file_type)))
# Sort for consistent ordering
self.all_image_paths = sorted(self.all_image_paths)
self.all_image_names = [os.path.basename(img_p) for img_p in self.all_image_paths]
def __len__(self):
return len(self.all_image_paths)
def __getitem__(self, idx):
# 1) Read image
image_name = self.all_image_names[idx]
image_path = os.path.join(self.image_dir, image_name)
label_filename = os.path.splitext(image_name)[0] + ".txt"
label_path = os.path.join(self.label_dir, label_filename)
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
# 2) Resize image (to the model's expected size)
image_resized = cv2.resize(image, (self.width, self.height))
image_resized /= 255.0 # Scale pixel values to [0, 1]
# 3) Read bounding boxes (normalized) from .txt file
boxes = []
labels = []
if os.path.exists(label_path):
with open(label_path, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
if not line:
continue
# Format: class_id x_min y_min x_max y_max (all in [0..1])
parts = line.split()
class_id = int(parts[0]) # e.g. 0, 1, 2, ...
# Example: if you want class IDs to start at 1 for foreground
# and background=0, do:
label_idx = class_id + 1
xmin = float(parts[1])
ymin = float(parts[2])
xmax = float(parts[3])
ymax = float(parts[4])
# Convert normalized coords to absolute (in resized space)
x_min_final = xmin * self.width
y_min_final = ymin * self.height
x_max_final = xmax * self.width
y_max_final = ymax * self.height
# Ensure valid box
if x_max_final <= x_min_final:
x_max_final = x_min_final + 1
if y_max_final <= y_min_final:
y_max_final = y_min_final + 1
# Clip if out of bounds
x_min_final = max(0, min(x_min_final, self.width - 1))
x_max_final = max(0, min(x_max_final, self.width))
y_min_final = max(0, min(y_min_final, self.height - 1))
y_max_final = max(0, min(y_max_final, self.height))
boxes.append([x_min_final, y_min_final, x_max_final, y_max_final])
labels.append(label_idx)
# 4) Convert boxes & labels to Torch tensors
if len(boxes) == 0:
boxes = torch.zeros((0, 4), dtype=torch.float32)
labels = torch.zeros((0,), dtype=torch.int64)
else:
boxes = torch.tensor(boxes, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.int64)
# 5) Prepare the target dict
area = (
(boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
if len(boxes) > 0
else torch.tensor([], dtype=torch.float32)
)
iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
image_id = torch.tensor([idx])
target = {"boxes": boxes, "labels": labels, "area": area, "iscrowd": iscrowd, "image_id": image_id}
# 6) Albumentations transforms: pass Python lists, not Tensors
if self.transforms:
bboxes_list = boxes.cpu().numpy().tolist() # shape: list of [xmin, ymin, xmax, ymax]
labels_list = labels.cpu().numpy().tolist() # shape: list of ints
transformed = self.transforms(
image=image_resized,
bboxes=bboxes_list,
labels=labels_list,
)
# Reassign the image
image_resized = transformed["image"]
# Convert bboxes back to Torch Tensors
new_bboxes_list = transformed["bboxes"] # list of [xmin, ymin, xmax, ymax]
new_labels_list = transformed["labels"] # list of int
if len(new_bboxes_list) > 0:
new_bboxes = torch.tensor(new_bboxes_list, dtype=torch.float32)
new_labels = torch.tensor(new_labels_list, dtype=torch.int64)
else:
new_bboxes = torch.zeros((0, 4), dtype=torch.float32)
new_labels = torch.zeros((0,), dtype=torch.int64)
target["boxes"] = new_bboxes
target["labels"] = new_labels
return image_resized, target
# ---------------------------------------------------------
# Create train/valid datasets and loaders
# ---------------------------------------------------------
def create_train_dataset(DIR, RESIZE_TO, CLASSES, get_train_transform):
train_dataset = CustomDataset(
dir_path=DIR, width=RESIZE_TO, height=RESIZE_TO, classes=CLASSES, transforms=get_train_transform()
)
return train_dataset
def create_valid_dataset(DIR, RESIZE_TO, CLASSES, get_valid_transform):
valid_dataset = CustomDataset(
dir_path=DIR, width=RESIZE_TO, height=RESIZE_TO, classes=CLASSES, transforms=get_valid_transform()
)
return valid_dataset
def create_train_loader(train_dataset, BATCH_SIZE, NUM_WORKERS, collate_fn):
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=NUM_WORKERS,
collate_fn=collate_fn,
drop_last=True,
)
return train_loader
def create_valid_loader(valid_dataset, BATCH_SIZE, NUM_WORKERS, collate_fn):
valid_loader = DataLoader(
valid_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
collate_fn=collate_fn,
drop_last=True,
)
return valid_loader
class Averager:
"""
A class to keep track of running average of values (e.g. training loss).
"""
def __init__(self):
self.current_total = 0.0
self.iterations = 0.0
def send(self, value):
self.current_total += value
self.iterations += 1
@property
def value(self):
if self.iterations == 0:
return 0
else:
return self.current_total / self.iterations
def reset(self):
self.current_total = 0.0
self.iterations = 0.0
class SaveBestModel:
"""
Saves the model if the current epoch's validation mAP is higher
than all previously observed values.
"""
def __init__(self, best_valid_map=0.0):
self.best_valid_map = best_valid_map
def __call__(
self,
model,
current_valid_map,
epoch,
OUT_DIR,
):
if current_valid_map > self.best_valid_map:
self.best_valid_map = current_valid_map
print(f"\nBEST VALIDATION mAP: {self.best_valid_map}")
print(f"SAVING BEST MODEL FOR EPOCH: {epoch+1}\n")
torch.save(
{
"epoch": epoch + 1,
"model_state_dict": model.state_dict(),
},
f"{OUT_DIR}/best_model.pth",
)
def save_model(epoch, model, optimizer, OUT_DIR):
"""
Save the trained model (state dict) and optimizer state to disk.
"""
torch.save(
{
"epoch": epoch + 1,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
},
f"{OUT_DIR}/last_model.pth",
)
def save_loss_plot(OUT_DIR, train_loss_list, x_label="iterations", y_label="train loss", save_name="train_loss"):
"""
Saves the training loss curve.
"""
plt.figure(figsize=(10, 7))
plt.plot(train_loss_list, color="tab:blue")
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.savefig(f"{OUT_DIR}/{save_name}.png")
# plt.close()
print("SAVING PLOTS COMPLETE...")
def save_mAP(OUT_DIR, map_05, map):
"""
Saves the mAP@0.5 and mAP@0.5:0.95 curves per epoch.
"""
plt.figure(figsize=(10, 7))
plt.plot(map_05, color="tab:orange", linestyle="-", label="mAP@0.5")
plt.plot(map, color="tab:red", linestyle="-", label="mAP@0.5:0.95")
plt.xlabel("Epochs")
plt.ylabel("mAP")
plt.legend()
plt.savefig(f"{OUT_DIR}/map.png")
# plt.close()
print("SAVING mAP PLOTS COMPLETE...")
def create_model(num_classes=91):
"""
Creates a RetinaNet-ResNet50-FPN v2 model pre-trained on COCO.
Replaces the classification head for the required number of classes.
"""
model = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1)
num_anchors = model.head.classification_head.num_anchors
# Replace the classification head
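# in_channels=256 matches the FPN output channels, and GroupNorm(32) mirrors the
# normalization used by the pre-trained v2 classification head.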
model.head.classification_head = RetinaNetClassificationHead(
in_channels=256, num_anchors=num_anchors, num_classes=num_classes, norm_layer=partial(torch.nn.GroupNorm, 32)
)
return model
def train(train_data_loader, model, optimizer, train_loss_hist, DEVICE):
print("Training")
model.train()
# initialize tqdm progress bar
prog_bar = tqdm(train_data_loader, total=len(train_data_loader))
for i, data in enumerate(prog_bar):
optimizer.zero_grad()
images, targets = data
images = list(image.to(DEVICE) for image in images)
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
loss_value = losses.item()
train_loss_hist.send(loss_value)
losses.backward()
optimizer.step()
# update the loss value beside the progress bar for each iteration
prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
return loss_value
def validate(valid_loader, model, DEVICE):
"""
Function to perform validation on the validation dataset.
Returns the mAP values.
"""
print("Validating")
# Initialize the metric.
metric = MeanAveragePrecision()
# Set the model to evaluation mode.
model.eval()
# Initialize tqdm progress bar.
prog_bar = tqdm(valid_loader, total=len(valid_loader))
with torch.no_grad():
for i, data in enumerate(prog_bar):
images, targets = data
images = list(image.to(DEVICE) for image in images)
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
outputs = model(images)
# For torchmetrics, we need to format the predictions and targets.
# Predictions should be a list of dicts, each dict containing:
# boxes (FloatTensor[N, 4]), scores (FloatTensor[N]), labels (IntTensor[N])
# Targets should also be a list of dicts, each dict containing:
# boxes (FloatTensor[M, 4]), labels (IntTensor[M])
formatted_preds = []
formatted_targets = []
for j in range(len(outputs)):
formatted_preds.append({
"boxes": outputs[j]["boxes"],
"scores": outputs[j]["scores"],
"labels": outputs[j]["labels"],
})
formatted_targets.append({
"boxes": targets[j]["boxes"],
"labels": targets[j]["labels"],
})
metric.update(formatted_preds, formatted_targets)
# Compute the metrics.
metric_summary = metric.compute()
return metric_summary
def show_transformed_image(train_loader, CLASSES, DEVICE):
"""
Visualize transformed images from the `train_loader` for debugging.
Only runs if `VISUALIZE_TRANSFORMED_IMAGES = True` in the config cell above.
"""
if len(train_loader) > 0:
for i in range(2):
images, targets = next(iter(train_loader))
images = list(image.to(DEVICE) for image in images)
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
for i in range(len(images)):
if len(targets[i]["boxes"]) == 0:
continue
boxes = targets[i]["boxes"].cpu().numpy().astype(np.int32)
labels = targets[i]["labels"].cpu().numpy().astype(np.int32)
sample = images[i].permute(1, 2, 0).cpu().numpy()
# Convert the [0, 1] float image back to uint8 BGR so OpenCV can draw and display it correctly.
sample = (sample * 255).astype(np.uint8)
sample = cv2.cvtColor(sample, cv2.COLOR_RGB2BGR)
for box_num, box in enumerate(boxes):
cv2.rectangle(sample, (box[0], box[1]), (box[2], box[3]), (0, 0, 255), 2)
cv2.putText(
sample,
CLASSES[labels[box_num]],
(box[0], box[1] - 10),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
(0, 0, 255),
2,
)
cv2_imshow(sample)
cv2.waitKey(0)
cv2.destroyAllWindows()
if __name__ == "__main__":
os.makedirs("outputs", exist_ok=True)
train_dataset = create_train_dataset(TRAIN_DIR, RESIZE_TO, CLASSES, get_train_transform)
valid_dataset = create_valid_dataset(VALID_DIR, RESIZE_TO, CLASSES, get_valid_transform)
train_loader = create_train_loader(train_dataset, BATCH_SIZE, NUM_WORKERS, collate_fn)
valid_loader = create_valid_loader(valid_dataset, BATCH_SIZE, NUM_WORKERS, collate_fn)
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(valid_dataset)}\n")
# Initialize the model and move to the computation device.
model = create_model(num_classes=NUM_CLASSES)
model = model.to(DEVICE)
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, nesterov=True, weight_decay=0.0005)
scheduler = ReduceLROnPlateau(
optimizer,
mode="max", # we want to maximize mAP
factor=0.1, # reduce LR by this factor
patience=8, # wait 8 epochs with no improvement
threshold=0.005, # how much improvement is considered significant
cooldown=1,
)
# To monitor training loss
train_loss_hist = Averager()
# To store training loss and mAP values.
train_loss_list = []
map_50_list = []
map_list = []
# Name to save the trained model with.
MODEL_NAME = "model"
# Whether to show transformed images from data loader or not.
if VISUALIZE_TRANSFORMED_IMAGES:
show_transformed_image(train_loader, CLASSES, DEVICE)
# To save best model.
save_best_model = SaveBestModel()
metric = MeanAveragePrecision()
metric.warn_on_many_detections = False
# Training loop.
for epoch in range(NUM_EPOCHS):
print(f"\nEPOCH {epoch+1} of {NUM_EPOCHS}")
# Reset the training loss histories for the current epoch.
train_loss_hist.reset()
# Start timer and carry out training and validation.
start = time.time()
train_loss = train(train_loader, model, optimizer, train_loss_hist, DEVICE)
metric_summary = validate(valid_loader, model, DEVICE)
current_map_05_95 = float(metric_summary["map"])
current_map_05 = float(metric_summary["map_50"])
print(f"Epoch #{epoch+1} train loss: {train_loss_hist.value:.3f}")
print(f"Epoch #{epoch+1} mAP: {metric_summary['map']:.3f}")
end = time.time()
print(f"Took {((end - start) / 60):.3f} minutes for epoch {epoch+1}")
train_loss_list.append(train_loss)
map_50_list.append(metric_summary["map_50"])
map_list.append(metric_summary["map"])
# save the best model till now.
save_best_model(model, float(metric_summary["map"]), epoch, "outputs")
# Save the current epoch model.
save_model(epoch, model, optimizer, OUT_DIR)
# Save loss plot.
save_loss_plot(OUT_DIR, train_loss_list)
# Save mAP plot.
save_mAP(OUT_DIR, map_50_list, map_list)
scheduler.step(current_map_05_95)
print("Current LR:", scheduler.get_last_lr())
Output hidden; open in https://colab.research.google.com to view.
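After training, the best checkpoint written by `SaveBestModel` can be reloaded for inference. A minimal sketch, assuming `outputs/best_model.pth` exists and using a hypothetical test image path.
In [ ]:
# Reload the best checkpoint and run a single inference pass (hypothetical image path).
checkpoint = torch.load("outputs/best_model.pth", map_location=DEVICE)
model = create_model(num_classes=NUM_CLASSES)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()
# Preprocess one image the same way the validation dataset does.
image = cv2.imread("/content/sample.jpg")  # hypothetical test image
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
image = cv2.resize(image, (RESIZE_TO, RESIZE_TO)) / 255.0
tensor = torch.tensor(image).permute(2, 0, 1).to(DEVICE)
with torch.no_grad():
    outputs = model([tensor])
print(outputs[0]["boxes"].shape, outputs[0]["labels"][:5], outputs[0]["scores"][:5])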