In [1]:
import torch

BATCH_SIZE = 8  # Increase / decrease according to GPU memory.
RESIZE_TO = 640  # Resize the image for training and transforms.
NUM_EPOCHS = 30  # Number of epochs to train for.
NUM_WORKERS = 4  # Number of parallel workers for data loading.

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training images and labels files directory.
TRAIN_DIR = "/content/drive/MyDrive/Colab Notebooks/data/train"
# Validation images and labels files directory.
VALID_DIR = "/content/drive/MyDrive/Colab Notebooks/data/valid"

# Classes: 0 index is reserved for background.
CLASSES = ["__background__", "buffalo", "elephant", "rhino", "zebra"]


NUM_CLASSES = len(CLASSES)

# Whether to visualize images after creating the data loaders.
VISUALIZE_TRANSFORMED_IMAGES = True

# Location to save model and plots.
OUT_DIR = "outputs"
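Before moving on, a quick sanity check of this configuration can save a failed training run later. The cell below is a minimal sketch (assuming Google Drive is already mounted at /content/drive): it prints the selected device and verifies that the train and validation directories exist.
In [ ]:
import os

print(f"Using device: {DEVICE}")
for name, path in [("TRAIN_DIR", TRAIN_DIR), ("VALID_DIR", VALID_DIR)]:
    # Fail fast with a clear message instead of erroring inside the DataLoader later.
    status = "found" if os.path.isdir(path) else "MISSING"
    print(f"{name}: {path} -> {status}")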
In [2]:
!pip install config
Collecting config
  Downloading config-0.5.1-py2.py3-none-any.whl.metadata (1.4 kB)
Downloading config-0.5.1-py2.py3-none-any.whl (20 kB)
Installing collected packages: config
Successfully installed config-0.5.1
In [3]:
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2


def collate_fn(batch):
    """
    Handle batching for detection data: images can contain different numbers
    of objects (and therefore variable-size target tensors), so return tuples
    instead of stacking.
    """
    return tuple(zip(*batch))


def get_train_transform():
    # We keep "pascal_voc" because bounding box format is [x_min, y_min, x_max, y_max].
    return A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=45),
            A.Blur(blur_limit=3, p=0.2),
            A.MotionBlur(blur_limit=3, p=0.1),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.3),
            A.RandomScale(scale_limit=0.2, p=0.3),
            ToTensorV2(p=1.0),
        ],
        bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
    )


def get_valid_transform():
    return A.Compose(
        [
            ToTensorV2(p=1.0),
        ],
        bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
    )
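To make the role of `collate_fn` concrete, the illustrative cell below builds a toy batch of two (image, target) pairs with different numbers of boxes (the tensors here are made up for demonstration). The default collate would try to stack the targets into one tensor and fail; `collate_fn` simply regroups the batch into a tuple of images and a tuple of per-image target dicts, which is the structure torchvision detection models expect.
In [ ]:
import torch

# Two hypothetical samples with different numbers of boxes.
sample_a = (torch.rand(3, 640, 640), {"boxes": torch.zeros((2, 4)), "labels": torch.tensor([1, 2])})
sample_b = (torch.rand(3, 640, 640), {"boxes": torch.zeros((1, 4)), "labels": torch.tensor([3])})

images, targets = collate_fn([sample_a, sample_b])
print(len(images), len(targets))                              # 2 2
print(targets[0]["boxes"].shape, targets[1]["boxes"].shape)   # torch.Size([2, 4]) torch.Size([1, 4])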
In [4]:
import torch
import cv2
import numpy as np
import os
import glob

#from config import CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE
from torch.utils.data import Dataset, DataLoader
#from custom_utils import collate_fn, get_train_transform, get_valid_transform


class CustomDataset(Dataset):
    def __init__(self, dir_path, width, height, classes, transforms=None):
        """
        :param dir_path: Directory containing 'images/' and 'labels/' subfolders.
        :param width: Resized image width.
        :param height: Resized image height.
        :param classes: List of class names (or an indexing scheme).
        :param transforms: Albumentations transformations to apply.
        """
        self.transforms = transforms
        self.dir_path = dir_path
        self.image_dir = os.path.join(self.dir_path, "images")
        self.label_dir = os.path.join(self.dir_path, "labels")
        self.width = width
        self.height = height
        self.classes = classes

        # Gather all image paths
        self.image_file_types = ["*.jpg", "*.jpeg", "*.png", "*.ppm", "*.JPG"]
        self.all_image_paths = []
        for file_type in self.image_file_types:
            self.all_image_paths.extend(glob.glob(os.path.join(self.image_dir, file_type)))

        # Sort for consistent ordering
        self.all_image_paths = sorted(self.all_image_paths)
        self.all_image_names = [os.path.basename(img_p) for img_p in self.all_image_paths]

    def __len__(self):
        return len(self.all_image_paths)

    def __getitem__(self, idx):
        # 1) Read image
        image_name = self.all_image_names[idx]
        image_path = os.path.join(self.image_dir, image_name)
        label_filename = os.path.splitext(image_name)[0] + ".txt"
        label_path = os.path.join(self.label_dir, label_filename)

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

        # 2) Resize image (to the model's expected size)
        image_resized = cv2.resize(image, (self.width, self.height))
        image_resized /= 255.0  # Scale pixel values to [0, 1]

        # 3) Read bounding boxes (normalized) from .txt file
        boxes = []
        labels = []
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                lines = f.readlines()

            for line in lines:
                line = line.strip()
                if not line:
                    continue
                # Format: class_id x_min y_min x_max y_max  (all in [0..1])
                parts = line.split()
                class_id = int(parts[0])  # e.g. 0, 1, 2, ...
                xmin = float(parts[1])
                ymin = float(parts[2])
                xmax = float(parts[3])
                ymax = float(parts[4])

                # Example: if you want class IDs to start at 1 for foreground
                # and background=0, do:
                label_idx = class_id + 1

                # Convert normalized coords to absolute (in resized space)
                x_min_final = xmin * self.width
                y_min_final = ymin * self.height
                x_max_final = xmax * self.width
                y_max_final = ymax * self.height

                # Ensure valid box
                if x_max_final <= x_min_final:
                    x_max_final = x_min_final + 1
                if y_max_final <= y_min_final:
                    y_max_final = y_min_final + 1

                # Clip if out of bounds
                x_min_final = max(0, min(x_min_final, self.width - 1))
                x_max_final = max(0, min(x_max_final, self.width))
                y_min_final = max(0, min(y_min_final, self.height - 1))
                y_max_final = max(0, min(y_max_final, self.height))

                boxes.append([x_min_final, y_min_final, x_max_final, y_max_final])
                labels.append(label_idx)

        # 4) Convert boxes & labels to Torch tensors
        if len(boxes) == 0:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.tensor(labels, dtype=torch.int64)

        # 5) Prepare the target dict
        area = (
            (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            if len(boxes) > 0
            else torch.tensor([], dtype=torch.float32)
        )
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        image_id = torch.tensor([idx])

        target = {"boxes": boxes, "labels": labels, "area": area, "iscrowd": iscrowd, "image_id": image_id}

        # 6) Albumentations transforms: pass Python lists, not Tensors
        if self.transforms:
            bboxes_list = boxes.cpu().numpy().tolist()  # shape: list of [xmin, ymin, xmax, ymax]
            labels_list = labels.cpu().numpy().tolist()  # shape: list of ints

            transformed = self.transforms(
                image=image_resized,
                bboxes=bboxes_list,
                labels=labels_list,
            )

            # Reassign the image
            image_resized = transformed["image"]

            # Convert bboxes back to Torch Tensors
            new_bboxes_list = transformed["bboxes"]  # list of [xmin, ymin, xmax, ymax]
            new_labels_list = transformed["labels"]  # list of int

            if len(new_bboxes_list) > 0:
                new_bboxes = torch.tensor(new_bboxes_list, dtype=torch.float32)
                new_labels = torch.tensor(new_labels_list, dtype=torch.int64)
            else:
                new_bboxes = torch.zeros((0, 4), dtype=torch.float32)
                new_labels = torch.zeros((0,), dtype=torch.int64)

            target["boxes"] = new_bboxes
            target["labels"] = new_labels

        return image_resized, target
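As a worked example of the label parsing above: with RESIZE_TO = 640, a hypothetical label line `1 0.25 0.10 0.50 0.40` maps class_id 1 to label index 2 ("elephant" in CLASSES) and to the absolute box [160.0, 64.0, 320.0, 256.0] in the resized image. The small cell below reproduces that arithmetic outside the dataset class.
In [ ]:
line = "1 0.25 0.10 0.50 0.40"  # hypothetical label file line
class_id, xmin, ymin, xmax, ymax = line.split()
label_idx = int(class_id) + 1  # background is index 0, so foreground classes start at 1
box = [float(xmin) * RESIZE_TO, float(ymin) * RESIZE_TO, float(xmax) * RESIZE_TO, float(ymax) * RESIZE_TO]
print(CLASSES[label_idx], box)  # elephant [160.0, 64.0, 320.0, 256.0]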
In [5]:
# ---------------------------------------------------------
# Debug/demo if run directly
# ---------------------------------------------------------
if __name__ == "__main__":
    # Example usage with no transforms for debugging
    dataset = CustomDataset(dir_path=TRAIN_DIR, width=RESIZE_TO, height=RESIZE_TO, classes=CLASSES, transforms=None)
    print(f"Number of training images: {len(dataset)}")

    if len(dataset) > 0:
        from google.colab.patches import cv2_imshow # Import cv2_imshow

        def visualize_sample(image, target):
            """
            Visualize a single sample using OpenCV. Expects
            `image` as a NumPy array of shape (H, W, 3) in [0..1].
            """
            # Convert [0,1] float -> [0,255] uint8
            img = (image * 255).astype(np.uint8)
            # Convert RGB -> BGR
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

            boxes = target["boxes"].cpu().numpy().astype(np.int32)
            labels = target["labels"].cpu().numpy().astype(np.int32)

            for i, box in enumerate(boxes):
                x1, y1, x2, y2 = box
                class_idx = labels[i]

                # If your class_idx starts at 1 for "first class", ensure you handle that:
                # e.g. if CLASSES = ["background", "class1", "class2", ...]
                if 0 <= class_idx < len(CLASSES):
                    class_str = CLASSES[class_idx]
                else:
                    class_str = f"Label_{class_idx}"

                cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
                cv2.putText(
                    img, class_str, (x1, max(y1 - 5, 0)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2
                )

            cv2_imshow(img) # Use cv2_imshow instead of cv2.imshow
            cv2.waitKey(0)

        # Visualize a few samples
        NUM_SAMPLES_TO_VISUALIZE = 10
        for i in range(min(NUM_SAMPLES_TO_VISUALIZE, len(dataset))):
            image, target = dataset[i]  # No transforms in this example
            # `image` is shape (H, W, 3) in [0..1]
            print(f"Visualizing sample {i}, boxes: {target['boxes'].shape[0]}")
            visualize_sample(image, target)
        cv2.destroyAllWindows()
    else:
        print("Dataset is empty. Cannot visualize samples.")
In [ ]:
def create_train_loader(train_dataset, num_workers=0):
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True,
    )
    return train_loader


def create_valid_loader(valid_dataset, num_workers=0):
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True,
    )
    return valid_loader
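With the dataset class and these loader helpers in place, a single batch can be pulled as a smoke test before training. The sketch below assumes the albumentations-based transforms defined earlier are available and that TRAIN_DIR is populated; it shows that each batch is a tuple of image tensors and a matching tuple of target dicts, the format the detection model consumes.
In [ ]:
train_dataset = CustomDataset(TRAIN_DIR, RESIZE_TO, RESIZE_TO, CLASSES, transforms=get_train_transform())
train_loader = create_train_loader(train_dataset, num_workers=0)

images, targets = next(iter(train_loader))
print(len(images), len(targets))                 # both equal BATCH_SIZE
print(images[0].shape)                           # e.g. torch.Size([3, 640, 640])
print(targets[0]["boxes"].shape, targets[0]["labels"].shape)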
In [6]:
!pip install torchmetrics
Collecting torchmetrics
  Downloading torchmetrics-1.8.1-py3-none-any.whl.metadata (22 kB)
Requirement already satisfied: numpy>1.20.0 in /usr/local/lib/python3.12/dist-packages (from torchmetrics) (2.0.2)
Requirement already satisfied: packaging>17.1 in /usr/local/lib/python3.12/dist-packages (from torchmetrics) (25.0)
Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from torchmetrics) (2.8.0+cu126)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (75.2.0)
Requirement already satisfied: typing_extensions in /usr/local/lib/python3.12/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (4.14.1)
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (3.19.1)
Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (1.13.3)
Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (3.5)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (3.1.6)
Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (2025.3.0)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.6.77)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.6.77)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.6.80)
Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (9.10.2.21)
Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.6.4.1)
Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (11.3.0.4)
Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (10.3.7.77)
Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (11.7.1.2)
Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.5.4.2)
Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (0.7.1)
Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (2.27.3)
Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.6.77)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.6.85)
Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (1.11.1.6)
Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (3.4.0)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=2.0.0->torchmetrics) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=2.0.0->torchmetrics) (3.0.2)
Downloading torchmetrics-1.8.1-py3-none-any.whl (982 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 983.0/983.0 kB 40.4 MB/s eta 0:00:00
Downloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.1
In [7]:
from tqdm.auto import tqdm
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

import torch
import matplotlib.pyplot as plt
import time
import os
In [8]:
plt.style.use("ggplot")

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
In [9]:
# Function for running training iterations.
def train(train_data_loader, model):
    print("Training")
    model.train()

    # initialize tqdm progress bar
    prog_bar = tqdm(train_data_loader, total=len(train_data_loader))

    for i, data in enumerate(prog_bar):
        optimizer.zero_grad()
        images, targets = data

        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        train_loss_hist.send(loss_value)

        losses.backward()
        optimizer.step()

        # update the loss value beside the progress bar for each iteration
        prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
    return loss_value
In [10]:
from tqdm.auto import tqdm
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

import torch
import matplotlib.pyplot as plt
import time
import os
import torchvision
from functools import partial
from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
import cv2
import numpy as np
import glob
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from google.colab.patches import cv2_imshow # Import cv2_imshow

# Define functions here
def collate_fn(batch):
    """
    Handle batching for detection data: images can contain different numbers
    of objects (and therefore variable-size target tensors), so return tuples
    instead of stacking.
    """
    return tuple(zip(*batch))

def get_train_transform():
    # We keep "pascal_voc" because bounding box format is [x_min, y_min, x_max, y_max].
    return A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=45),
            A.Blur(blur_limit=3, p=0.2),
            A.MotionBlur(blur_limit=3, p=0.1),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.3),
            A.RandomScale(scale_limit=0.2, p=0.3),
            ToTensorV2(p=1.0),
        ],
        bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
    )


def get_valid_transform():
    return A.Compose(
        [
            ToTensorV2(p=1.0),
        ],
        bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
    )

class CustomDataset(Dataset):
    def __init__(self, dir_path, width, height, classes, transforms=None):
        """
        :param dir_path: Directory containing 'images/' and 'labels/' subfolders.
        :param width: Resized image width.
        :param height: Resized image height.
        :param classes: List of class names (or an indexing scheme).
        :param transforms: Albumentations transformations to apply.
        """
        self.transforms = transforms
        self.dir_path = dir_path
        self.image_dir = os.path.join(self.dir_path, "images")
        self.label_dir = os.path.join(self.dir_path, "labels")
        self.width = width
        self.height = height
        self.classes = classes

        # Gather all image paths
        self.image_file_types = ["*.jpg", "*.jpeg", "*.png", "*.ppm", "*.JPG"]
        self.all_image_paths = []
        for file_type in self.image_file_types:
            self.all_image_paths.extend(glob.glob(os.path.join(self.image_dir, file_type)))

        # Sort for consistent ordering
        self.all_image_paths = sorted(self.all_image_paths)
        self.all_image_names = [os.path.basename(img_p) for img_p in self.all_image_paths]

    def __len__(self):
        return len(self.all_image_paths)

    def __getitem__(self, idx):
        # 1) Read image
        image_name = self.all_image_names[idx]
        image_path = os.path.join(self.image_dir, image_name)
        label_filename = os.path.splitext(image_name)[0] + ".txt"
        label_path = os.path.join(self.label_dir, label_filename)

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

        # 2) Resize image (to the model's expected size)
        image_resized = cv2.resize(image, (self.width, self.height))
        image_resized /= 255.0  # Scale pixel values to [0, 1]

        # 3) Read bounding boxes (normalized) from .txt file
        boxes = []
        labels = []
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                lines = f.readlines()

            for line in lines:
                line = line.strip()
                if not line:
                    continue
                # Format: class_id x_min y_min x_max y_max  (all in [0..1])
                parts = line.split()
                class_id = int(parts[0])  # e.g. 0, 1, 2, ...
                # Example: if you want class IDs to start at 1 for foreground
                # and background=0, do:
                label_idx = class_id + 1

                xmin = float(parts[1])
                ymin = float(parts[2])
                xmax = float(parts[3])
                ymax = float(parts[4])

                # Convert normalized coords to absolute (in resized space)
                x_min_final = xmin * self.width
                y_min_final = ymin * self.height
                x_max_final = xmax * self.width
                y_max_final = ymax * self.height

                # Ensure valid box
                if x_max_final <= x_min_final:
                    x_max_final = x_min_final + 1
                if y_max_final <= y_min_final:
                    y_max_final = y_min_final + 1

                # Clip if out of bounds
                x_min_final = max(0, min(x_min_final, self.width - 1))
                x_max_final = max(0, min(x_max_final, self.width))
                y_min_final = max(0, min(y_min_final, self.height - 1))
                y_max_final = max(0, min(y_max_final, self.height))

                boxes.append([x_min_final, y_min_final, x_max_final, y_max_final])
                labels.append(label_idx)

        # 4) Convert boxes & labels to Torch tensors
        if len(boxes) == 0:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.tensor(labels, dtype=torch.int64)

        # 5) Prepare the target dict
        area = (
            (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            if len(boxes) > 0
            else torch.tensor([], dtype=torch.float32)
        )
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        image_id = torch.tensor([idx])

        target = {"boxes": boxes, "labels": labels, "area": area, "iscrowd": iscrowd, "image_id": image_id}

        # 6) Albumentations transforms: pass Python lists, not Tensors
        if self.transforms:
            bboxes_list = boxes.cpu().numpy().tolist()  # shape: list of [xmin, ymin, xmax, ymax]
            labels_list = labels.cpu().numpy().tolist()  # shape: list of ints

            transformed = self.transforms(
                image=image_resized,
                bboxes=bboxes_list,
                labels=labels_list,
            )

            # Reassign the image
            image_resized = transformed["image"]

            # Convert bboxes back to Torch Tensors
            new_bboxes_list = transformed["bboxes"]  # list of [xmin, ymin, xmax, ymax]
            new_labels_list = transformed["labels"]  # list of int

            if len(new_bboxes_list) > 0:
                new_bboxes = torch.tensor(new_bboxes_list, dtype=torch.float32)
                new_labels = torch.tensor(new_labels_list, dtype=torch.int64)
            else:
                new_bboxes = torch.zeros((0, 4), dtype=torch.float32)
                new_labels = torch.zeros((0,), dtype=torch.int64)

            target["boxes"] = new_bboxes
            target["labels"] = new_labels


        return image_resized, target

# ---------------------------------------------------------
# Create train/valid datasets and loaders
# ---------------------------------------------------------
def create_train_dataset(DIR, RESIZE_TO, CLASSES, get_train_transform):
    train_dataset = CustomDataset(
        dir_path=DIR, width=RESIZE_TO, height=RESIZE_TO, classes=CLASSES, transforms=get_train_transform()
    )
    return train_dataset


def create_valid_dataset(DIR, RESIZE_TO, CLASSES, get_valid_transform):
    valid_dataset = CustomDataset(
        dir_path=DIR, width=RESIZE_TO, height=RESIZE_TO, classes=CLASSES, transforms=get_valid_transform()
    )
    return valid_dataset


def create_train_loader(train_dataset, BATCH_SIZE, NUM_WORKERS, collate_fn):
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn,
        drop_last=True,
    )
    return train_loader


def create_valid_loader(valid_dataset, BATCH_SIZE, NUM_WORKERS, collate_fn):
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn,
        drop_last=True,
    )
    return valid_loader

class Averager:
    """
    A class to keep track of running average of values (e.g. training loss).
    """

    def __init__(self):
        self.current_total = 0.0
        self.iterations = 0.0

    def send(self, value):
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        if self.iterations == 0:
            return 0
        else:
            return self.current_total / self.iterations

    def reset(self):
        self.current_total = 0.0
        self.iterations = 0.0


class SaveBestModel:
    """
    Saves the model if the current epoch's validation mAP is higher
    than all previously observed values.
    """

    def __init__(self, best_valid_map=float(0)):
        self.best_valid_map = best_valid_map

    def __call__(
        self,
        model,
        current_valid_map,
        epoch,
        OUT_DIR,
    ):
        if current_valid_map > self.best_valid_map:
            self.best_valid_map = current_valid_map
            print(f"\nBEST VALIDATION mAP: {self.best_valid_map}")
            print(f"SAVING BEST MODEL FOR EPOCH: {epoch+1}\n")
            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                },
                f"{OUT_DIR}/best_model.pth",
            )

def save_model(epoch, model, optimizer, OUT_DIR):
    """
    Save the trained model (state dict) and optimizer state to disk.
    """
    torch.save(
        {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        f"{OUT_DIR}/last_model.pth",
    )

def save_loss_plot(OUT_DIR, train_loss_list, x_label="iterations", y_label="train loss", save_name="train_loss"):
    """
    Saves the training loss curve.
    """
    plt.figure(figsize=(10, 7))
    plt.plot(train_loss_list, color="tab:blue")
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.savefig(f"{OUT_DIR}/{save_name}.png")
    # plt.close()
    print("SAVING PLOTS COMPLETE...")

def save_mAP(OUT_DIR, map_05, map):
    """
    Saves the mAP@0.5 and mAP@0.5:0.95 curves per epoch.
    """
    plt.figure(figsize=(10, 7))
    plt.plot(map_05, color="tab:orange", linestyle="-", label="mAP@0.5")
    plt.plot(map, color="tab:red", linestyle="-", label="mAP@0.5:0.95")
    plt.xlabel("Epochs")
    plt.ylabel("mAP")
    plt.legend()
    plt.savefig(f"{OUT_DIR}/map.png")
    # plt.close()
    print("SAVING mAP PLOTS COMPLETE...")

def create_model(num_classes=91):
    """
    Creates a RetinaNet-ResNet50-FPN v2 model pre-trained on COCO.
    Replaces the classification head for the required number of classes.
    """
    model = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1)
    num_anchors = model.head.classification_head.num_anchors

    # Replace the classification head
    model.head.classification_head = RetinaNetClassificationHead(
        in_channels=256, num_anchors=num_anchors, num_classes=num_classes, norm_layer=partial(torch.nn.GroupNorm, 32)
    )
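    # Note: 256 matches the channel width of the FPN feature maps used by this
    # torchvision RetinaNet, so only the number of classes changes relative to
    # the pre-trained COCO head (here num_classes is NUM_CLASSES from the config cell).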
    return model

def train(train_data_loader, model, optimizer, train_loss_hist, DEVICE):
    print("Training")
    model.train()

    # initialize tqdm progress bar
    prog_bar = tqdm(train_data_loader, total=len(train_data_loader))

    for i, data in enumerate(prog_bar):
        optimizer.zero_grad()
        images, targets = data

        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
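        # In train mode the torchvision RetinaNet returns a dict of losses
        # (typically 'classification' and 'bbox_regression'), which are summed
        # below into a single scalar for backpropagation.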

        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        train_loss_hist.send(loss_value)

        losses.backward()
        optimizer.step()

        # update the loss value beside the progress bar for each iteration
        prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
    return loss_value

def validate(valid_loader, model, DEVICE):
    """
    Function to perform validation on the validation dataset.
    Returns the mAP values.
    """
    print("Validating")
    # Initialize the metric.
    metric = MeanAveragePrecision()
    # Set the model to evaluation mode.
    model.eval()
    # Initialize tqdm progress bar.
    prog_bar = tqdm(valid_loader, total=len(valid_loader))

    with torch.no_grad():
        for i, data in enumerate(prog_bar):
            images, targets = data

            images = list(image.to(DEVICE) for image in images)
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

            outputs = model(images)

            # For torchmetrics, we need to format the predictions and targets.
            # Predictions should be a list of dicts, each dict containing:
            # boxes (FloatTensor[N, 4]), scores (FloatTensor[N]), labels (IntTensor[N])
            # Targets should also be a list of dicts, each dict containing:
            # boxes (FloatTensor[M, 4]), labels (IntTensor[M])

            formatted_preds = []
            formatted_targets = []

            for j in range(len(outputs)):
                formatted_preds.append({
                    "boxes": outputs[j]["boxes"],
                    "scores": outputs[j]["scores"],
                    "labels": outputs[j]["labels"],
                })
                formatted_targets.append({
                    "boxes": targets[j]["boxes"],
                    "labels": targets[j]["labels"],
                })

            metric.update(formatted_preds, formatted_targets)

    # Compute the metrics.
    metric_summary = metric.compute()
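    # The returned dict includes, among other entries, 'map' (mAP@0.5:0.95)
    # and 'map_50' (mAP@0.5); the training loop below logs these and uses them
    # for checkpointing and LR scheduling.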
    return metric_summary

def show_transformed_image(train_loader, CLASSES, DEVICE):
    """
    Visualize transformed images from the `train_loader` for debugging.
    Only runs if `VISUALIZE_TRANSFORMED_IMAGES = True` in the config cell.
    """
    if len(train_loader) > 0:
        for i in range(2):
            images, targets = next(iter(train_loader))
            images = list(image.to(DEVICE) for image in images)

            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
            for idx in range(len(images)):
                if len(targets[idx]["boxes"]) == 0:
                    continue
                boxes = targets[idx]["boxes"].cpu().numpy().astype(np.int32)
                labels = targets[idx]["labels"].cpu().numpy().astype(np.int32)
                # Convert the CHW float tensor in [0, 1] back to an HWC uint8 image for display.
                sample = (images[idx].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                sample = cv2.cvtColor(sample, cv2.COLOR_RGB2BGR)

                for box_num, box in enumerate(boxes):
                    cv2.rectangle(sample, (box[0], box[1]), (box[2], box[3]), (0, 0, 255), 2)
                    cv2.putText(
                        sample,
                        CLASSES[labels[box_num]],
                        (box[0], box[1] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1.0,
                        (0, 0, 255),
                        2,
                    )
                cv2_imshow(sample)
                cv2.waitKey(0)
                cv2.destroyAllWindows()


if __name__ == "__main__":
    os.makedirs("outputs", exist_ok=True)
    train_dataset = create_train_dataset(TRAIN_DIR, RESIZE_TO, CLASSES, get_train_transform)
    valid_dataset = create_valid_dataset(VALID_DIR, RESIZE_TO, CLASSES, get_valid_transform)
    train_loader = create_train_loader(train_dataset, BATCH_SIZE, NUM_WORKERS, collate_fn)
    valid_loader = create_valid_loader(valid_dataset, BATCH_SIZE, NUM_WORKERS, collate_fn)
    print(f"Number of training samples: {len(train_dataset)}")
    print(f"Number of validation samples: {len(valid_dataset)}\n")

    # Initialize the model and move to the computation device.
    model = create_model(num_classes=NUM_CLASSES)
    model = model.to(DEVICE)
    print(model)
    # Total parameters and trainable parameters.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, nesterov=True, weight_decay=0.0005)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="max",  # we want to maximize mAP
        factor=0.1,  # reduce LR by this factor
        patience=8,  # wait 8 epochs with no improvement
        threshold=0.005,  # how much improvement is considered significant
        cooldown=1,
    )

    # To monitor training loss
    train_loss_hist = Averager()
    # To store training loss and mAP values.
    train_loss_list = []
    map_50_list = []
    map_list = []

    # Name to save the trained model with.
    MODEL_NAME = "model"

    # Whether to show transformed images from data loader or not.
    if VISUALIZE_TRANSFORMED_IMAGES:
        show_transformed_image(train_loader, CLASSES, DEVICE)

    # To save best model.
    save_best_model = SaveBestModel()

    metric = MeanAveragePrecision()
    metric.warn_on_many_detections = False

    # Training loop.
    for epoch in range(NUM_EPOCHS):
        print(f"\nEPOCH {epoch+1} of {NUM_EPOCHS}")

        # Reset the training loss history for the current epoch.
        train_loss_hist.reset()

        # Start timer and carry out training and validation.
        start = time.time()
        train_loss = train(train_loader, model, optimizer, train_loss_hist, DEVICE)
        metric_summary = validate(valid_loader, model, DEVICE)
        current_map_05_95 = float(metric_summary["map"])
        current_map_05 = float(metric_summary["map_50"])
        print(f"Epoch #{epoch+1} train loss: {train_loss_hist.value:.3f}")
        print(f"Epoch #{epoch+1} mAP: {metric_summary['map']:.3f}")
        end = time.time()
        print(f"Took {((end - start) / 60):.3f} minutes for epoch {epoch+1}")

        train_loss_list.append(train_loss)
        map_50_list.append(current_map_05)
        map_list.append(current_map_05_95)

        # Save the best model so far.
        save_best_model(model, current_map_05_95, epoch, OUT_DIR)
        # Save the current epoch model.
        save_model(epoch, model, optimizer, OUT_DIR)

        # Save loss plot.
        save_loss_plot(OUT_DIR, train_loss_list)

        # Save mAP plot.
        save_mAP(OUT_DIR, map_50_list, map_list)
        scheduler.step(current_map_05_95)
        print("Current LR:", scheduler.get_last_lr())