First the dataset is set up: the cell below generates the african-wildlife.yaml file, which defines the train/val/test image paths, the four classes, and the download URL for the data.

In [1]:
yaml_content = """
path: african-wildlife
train: images/train
val: images/val
test: images/test

names:
  0: buffalo
  1: elephant
  2: rhino
  3: zebra

download: https://github.com/ultralytics/assets/releases/download/v0.0.0/african-wildlife.zip
"""

with open("african-wildlife.yaml", "w") as f:
    f.write(yaml_content)

Exploratory analysis of the Ultralytics African Wildlife dataset. First, change the Colab runtime type to a T4 GPU. The process requires installing ultralytics. The ! symbol tells the environment to run a system (shell/terminal) command rather than a Python command. The package is installed into the runtime's virtual environment.
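As a quick sanity check, the GPU can also be queried from Python (a minimal sketch; torch comes preinstalled in the Colab image):

import torch

# Minimal sketch: confirm that PyTorch can see the Colab GPU
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))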

In [2]:
!nvidia-smi
Sat Aug 23 01:17:27 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   42C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
In [3]:
!pip install ultralytics
Collecting ultralytics
  Downloading ultralytics-8.3.184-py3-none-any.whl.metadata (37 kB)
Requirement already satisfied: numpy>=1.23.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.0.2)
Requirement already satisfied: matplotlib>=3.3.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (3.10.0)
Requirement already satisfied: opencv-python>=4.6.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (4.12.0.88)
Requirement already satisfied: pillow>=7.1.2 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (11.3.0)
Requirement already satisfied: pyyaml>=5.3.1 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (6.0.2)
Requirement already satisfied: requests>=2.23.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.32.4)
Requirement already satisfied: scipy>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (1.16.1)
Requirement already satisfied: torch>=1.8.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.8.0+cu126)
Requirement already satisfied: torchvision>=0.9.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (0.23.0+cu126)
Requirement already satisfied: tqdm>=4.64.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (4.67.1)
Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from ultralytics) (5.9.5)
Requirement already satisfied: py-cpuinfo in /usr/local/lib/python3.12/dist-packages (from ultralytics) (9.0.0)
Requirement already satisfied: pandas>=1.1.4 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.2.2)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.16-py3-none-any.whl.metadata (14 kB)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (1.3.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (4.59.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (1.4.9)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (25.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (3.2.3)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas>=1.1.4->ultralytics) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas>=1.1.4->ultralytics) (2025.2)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (3.4.3)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (2025.8.3)
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.19.1)
Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (4.14.1)
Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (75.2.0)
Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (1.13.3)
Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.5)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.1.6)
Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (2025.3.0)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.77)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.77)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.80)
Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (9.10.2.21)
Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.4.1)
Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (11.3.0.4)
Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (10.3.7.77)
Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (11.7.1.2)
Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.5.4.2)
Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (0.7.1)
Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (2.27.3)
Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.77)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.85)
Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (1.11.1.6)
Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.4.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib>=3.3.0->ultralytics) (1.17.0)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.8.0->ultralytics) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.8.0->ultralytics) (3.0.2)
Downloading ultralytics-8.3.184-py3-none-any.whl (1.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 24.1 MB/s eta 0:00:00
Downloading ultralytics_thop-2.0.16-py3-none-any.whl (28 kB)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.184 ultralytics-thop-2.0.16

YOLO Model

The following command trains a YOLO model on the African Wildlife dataset. The ultralytics package provides the yolo command used for training; detect train indicates that we are doing object detection and training the model. The .yaml dataset file points to the training and validation data and lists the classes (animal species). yolov8s.pt is a pretrained YOLO model and is better than yolo11n.pt. The model goes over the data for 30 epochs, the images are resized to 640×640 pixels, and batch defines how many images are used in each iteration.

Running the model generates the folders african-wildlife -> images -> (test, train, val) and labels -> (test, train, val), plus runs -> detect -> train -> weights and sample_data, and the files yolo11n.pt and yolov8s.pt.
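For reference, the same run can also be launched from the Ultralytics Python API instead of the CLI; this is a sketch using the same data, model, and hyperparameters as the command in the next cell:

from ultralytics import YOLO

# Sketch: Python-API equivalent of the CLI training command below
model = YOLO("yolov8s.pt")          # pretrained weights
model.train(
    data="african-wildlife.yaml",   # dataset config created above
    epochs=30,
    imgsz=640,
    batch=8,
)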

In [4]:
# Start training from a pretrained *.pt model
!yolo detect train data=african-wildlife.yaml model=yolov8s.pt epochs=30 imgsz=640 batch=8
Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.
Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8s.pt to 'yolov8s.pt': 100% 21.5M/21.5M [00:00<00:00, 51.8MB/s]
Ultralytics 8.3.184 🚀 Python-3.12.11 torch-2.8.0+cu126 CUDA:0 (Tesla T4, 15095MiB)
engine/trainer: agnostic_nms=False, amp=True, augment=False, auto_augment=randaugment, batch=8, bgr=0.0, box=7.5, cache=False, cfg=None, classes=None, close_mosaic=10, cls=0.5, conf=None, copy_paste=0.0, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=african-wildlife.yaml, degrees=0.0, deterministic=True, device=None, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, epochs=30, erasing=0.4, exist_ok=False, fliplr=0.5, flipud=0.0, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, imgsz=640, int8=False, iou=0.7, keras=False, kobj=1.0, line_width=None, lr0=0.01, lrf=0.01, mask_ratio=4, max_det=300, mixup=0.0, mode=train, model=yolov8s.pt, momentum=0.937, mosaic=1.0, multi_scale=False, name=train, nbs=64, nms=False, opset=None, optimize=False, optimizer=auto, overlap_mask=True, patience=100, perspective=0.0, plots=True, pose=12.0, pretrained=True, profile=False, project=None, rect=False, resume=False, retina_masks=False, save=True, save_conf=False, save_crop=False, save_dir=runs/detect/train, save_frames=False, save_json=False, save_period=-1, save_txt=False, scale=0.5, seed=0, shear=0.0, show=False, show_boxes=True, show_conf=True, show_labels=True, simplify=True, single_cls=False, source=None, split=val, stream_buffer=False, task=detect, time=None, tracker=botsort.yaml, translate=0.1, val=True, verbose=True, vid_stride=1, visualize=False, warmup_bias_lr=0.1, warmup_epochs=3.0, warmup_momentum=0.8, weight_decay=0.0005, workers=8, workspace=None

WARNING ⚠️ Dataset 'african-wildlife.yaml' images not found, missing path '/content/datasets/african-wildlife/images/val'
Downloading https://ultralytics.com/assets/african-wildlife.zip to '/content/datasets/african-wildlife.zip': 100% 100M/100M [00:02<00:00, 46.4MB/s] 
Unzipping /content/datasets/african-wildlife.zip to /content/datasets/african-wildlife...: 100% 3018/3018 [00:00<00:00, 3999.86file/s]
Dataset download success ✅ (3.7s), saved to /content/datasets

Downloading https://ultralytics.com/assets/Arial.ttf to '/root/.config/Ultralytics/Arial.ttf': 100% 755k/755k [00:00<00:00, 21.2MB/s]
Overriding model.yaml nc=80 with nc=4

                   from  n    params  module                                       arguments                     
  0                  -1  1       928  ultralytics.nn.modules.conv.Conv             [3, 32, 3, 2]                 
  1                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]                
  2                  -1  1     29056  ultralytics.nn.modules.block.C2f             [64, 64, 1, True]             
  3                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]               
  4                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]           
  5                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]              
  6                  -1  2    788480  ultralytics.nn.modules.block.C2f             [256, 256, 2, True]           
  7                  -1  1   1180672  ultralytics.nn.modules.conv.Conv             [256, 512, 3, 2]              
  8                  -1  1   1838080  ultralytics.nn.modules.block.C2f             [512, 512, 1, True]           
  9                  -1  1    656896  ultralytics.nn.modules.block.SPPF            [512, 512, 5]                 
 10                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 11             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 12                  -1  1    591360  ultralytics.nn.modules.block.C2f             [768, 256, 1]                 
 13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 14             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 15                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]                 
 16                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]              
 17            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 18                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]                 
 19                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 20             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 21                  -1  1   1969152  ultralytics.nn.modules.block.C2f             [768, 512, 1]                 
 22        [15, 18, 21]  1   2117596  ultralytics.nn.modules.head.Detect           [4, [128, 256, 512]]          
Model summary: 129 layers, 11,137,148 parameters, 11,137,132 gradients, 28.7 GFLOPs

Transferred 349/355 items from pretrained weights
Freezing layer 'model.22.dfl.conv.weight'
AMP: running Automatic Mixed Precision (AMP) checks...
Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt to 'yolo11n.pt': 100% 5.35M/5.35M [00:00<00:00, 65.4MB/s]
AMP: checks passed ✅
train: Fast image access ✅ (ping: 0.0±0.0 ms, read: 1628.1±698.6 MB/s, size: 54.6 KB)
train: Scanning /content/datasets/african-wildlife/labels/train... 1052 images, 0 backgrounds, 0 corrupt: 100% 1052/1052 [00:00<00:00, 2384.53it/s]
train: New cache created: /content/datasets/african-wildlife/labels/train.cache
albumentations: Blur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01, method='weighted_average', num_output_channels=3), CLAHE(p=0.01, clip_limit=(1.0, 4.0), tile_grid_size=(8, 8))
val: Fast image access ✅ (ping: 0.0±0.0 ms, read: 451.1±225.8 MB/s, size: 42.2 KB)
val: Scanning /content/datasets/african-wildlife/labels/val... 225 images, 0 backgrounds, 0 corrupt: 100% 225/225 [00:00<00:00, 2045.87it/s]
val: New cache created: /content/datasets/african-wildlife/labels/val.cache
Plotting labels to runs/detect/train/labels.jpg... 
optimizer: 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
optimizer: AdamW(lr=0.00125, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
Image sizes 640 train, 640 val
Using 2 dataloader workers
Logging results to runs/detect/train
Starting training for 30 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       1/30      1.88G     0.8357      1.618       1.22          9        640: 100% 132/132 [00:26<00:00,  4.90it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:04<00:00,  3.50it/s]
                   all        225        379        0.8      0.751      0.829      0.632

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       2/30      2.63G     0.9346      1.184      1.264         16        640: 100% 132/132 [00:21<00:00,  6.04it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:02<00:00,  6.37it/s]
                   all        225        379      0.655      0.538        0.6      0.368

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       3/30      2.65G      1.016      1.215      1.315         18        640: 100% 132/132 [00:21<00:00,  6.09it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  7.71it/s]
                   all        225        379      0.546      0.649      0.679      0.449

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       4/30       2.7G      1.049      1.266      1.356         12        640: 100% 132/132 [00:22<00:00,  5.99it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.61it/s]
                   all        225        379      0.695      0.646      0.737      0.478

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       5/30      2.72G      1.015      1.165      1.316         21        640: 100% 132/132 [00:26<00:00,  5.07it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:02<00:00,  6.05it/s]
                   all        225        379      0.728      0.786      0.821      0.579

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       6/30      2.78G     0.9675      1.083      1.295         11        640: 100% 132/132 [00:21<00:00,  6.17it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:02<00:00,  6.62it/s]
                   all        225        379      0.827      0.731       0.84      0.597

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       7/30       2.8G     0.9589      1.044      1.286         13        640: 100% 132/132 [00:21<00:00,  6.07it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.49it/s]
                   all        225        379      0.827      0.724      0.843      0.584

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       8/30      2.85G      0.918     0.9576       1.26         10        640: 100% 132/132 [00:22<00:00,  5.89it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.85it/s]
                   all        225        379        0.9      0.811       0.89      0.655

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       9/30      2.87G     0.8834      0.953      1.241         50        640: 100% 132/132 [00:22<00:00,  5.91it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.50it/s]
                   all        225        379      0.881      0.789      0.886       0.66

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      10/30      2.92G     0.8862     0.9173      1.232         27        640: 100% 132/132 [00:22<00:00,  5.88it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.46it/s]
                   all        225        379      0.863      0.778      0.888      0.663

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      11/30      2.95G     0.8614     0.8738      1.207         15        640: 100% 132/132 [00:22<00:00,  5.83it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.43it/s]
                   all        225        379      0.864      0.852      0.915       0.69

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      12/30      2.99G     0.8199     0.8105      1.186         18        640: 100% 132/132 [00:22<00:00,  5.81it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.75it/s]
                   all        225        379      0.844       0.85      0.903      0.697

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      13/30      3.02G      0.795     0.7827      1.186         15        640: 100% 132/132 [00:22<00:00,  5.79it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.18it/s]
                   all        225        379      0.897       0.83      0.914      0.713

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      14/30      3.07G     0.7838     0.7929      1.177         16        640: 100% 132/132 [00:22<00:00,  5.82it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.18it/s]
                   all        225        379      0.859       0.83      0.916      0.718

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      15/30      3.09G     0.7604     0.7281      1.152         24        640: 100% 132/132 [00:22<00:00,  5.82it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:02<00:00,  7.46it/s]
                   all        225        379      0.896      0.844      0.915       0.71

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      16/30      3.14G     0.7542     0.7148      1.154         24        640: 100% 132/132 [00:22<00:00,  5.92it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:02<00:00,  6.45it/s]
                   all        225        379      0.894       0.87      0.923      0.731

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      17/30      3.16G     0.7455      0.705      1.139         18        640: 100% 132/132 [00:21<00:00,  6.13it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:02<00:00,  6.46it/s]
                   all        225        379      0.928      0.831      0.932      0.729

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      18/30      3.21G     0.7223     0.6931      1.129         12        640: 100% 132/132 [00:21<00:00,  6.09it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:02<00:00,  7.27it/s]
                   all        225        379      0.941      0.854      0.934       0.75

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      19/30      3.23G     0.6957     0.6472      1.108         14        640: 100% 132/132 [00:21<00:00,  6.01it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.61it/s]
                   all        225        379      0.896      0.895      0.944      0.749

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      20/30      3.29G     0.6984     0.6313       1.11         15        640: 100% 132/132 [00:23<00:00,  5.69it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  7.60it/s]
                   all        225        379      0.933      0.868      0.941      0.756
Closing dataloader mosaic
albumentations: Blur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01, method='weighted_average', num_output_channels=3), CLAHE(p=0.01, clip_limit=(1.0, 4.0), tile_grid_size=(8, 8))

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      21/30      3.31G     0.6153     0.5096      1.051          6        640: 100% 132/132 [00:24<00:00,  5.39it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.64it/s]
                   all        225        379      0.922       0.87      0.942      0.758

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      22/30      3.36G     0.6005      0.489      1.044          6        640: 100% 132/132 [00:21<00:00,  6.04it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.66it/s]
                   all        225        379      0.956      0.864      0.945      0.772

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      23/30      3.38G      0.584     0.4515      1.039          5        640: 100% 132/132 [00:22<00:00,  5.99it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.52it/s]
                   all        225        379       0.93      0.895      0.955      0.787

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      24/30      3.43G     0.5627     0.4348      1.027          6        640: 100% 132/132 [00:21<00:00,  6.03it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.68it/s]
                   all        225        379      0.948      0.887      0.957      0.793

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      25/30      3.62G     0.5428     0.4133      1.002          8        640: 100% 132/132 [00:21<00:00,  6.03it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.72it/s]
                   all        225        379      0.934      0.891      0.952      0.794

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      26/30      3.67G     0.5261     0.3889     0.9929          9        640: 100% 132/132 [00:21<00:00,  6.11it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:02<00:00,  6.99it/s]
                   all        225        379       0.92      0.906      0.955      0.795

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      27/30      3.69G     0.5112      0.383     0.9826          6        640: 100% 132/132 [00:20<00:00,  6.30it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:02<00:00,  6.72it/s]
                   all        225        379      0.924        0.9      0.949        0.8

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      28/30      3.74G     0.4894      0.357     0.9697         10        640: 100% 132/132 [00:21<00:00,  6.26it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.65it/s]
                   all        225        379      0.939      0.897      0.948      0.799

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      29/30      3.76G     0.4821     0.3559     0.9595          5        640: 100% 132/132 [00:22<00:00,  5.81it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.80it/s]
                   all        225        379      0.943      0.895      0.954      0.806

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      30/30      3.81G     0.4708      0.335      0.959          5        640: 100% 132/132 [00:21<00:00,  6.04it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:01<00:00,  8.76it/s]
                   all        225        379      0.941      0.912      0.957      0.815

30 epochs completed in 0.208 hours.
Optimizer stripped from runs/detect/train/weights/last.pt, 22.5MB
Optimizer stripped from runs/detect/train/weights/best.pt, 22.5MB

Validating runs/detect/train/weights/best.pt...
Ultralytics 8.3.184 🚀 Python-3.12.11 torch-2.8.0+cu126 CUDA:0 (Tesla T4, 15095MiB)
Model summary (fused): 72 layers, 11,127,132 parameters, 0 gradients, 28.4 GFLOPs
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% 15/15 [00:02<00:00,  5.87it/s]
                   all        225        379      0.941      0.912      0.957      0.816
               buffalo         62         89      0.924      0.921      0.957      0.826
              elephant         53         91      0.911      0.896      0.946      0.784
                 rhino         55         85      0.978      0.953      0.976      0.873
                 zebra         59        114      0.952      0.878      0.948      0.781
Speed: 0.2ms preprocess, 3.3ms inference, 0.0ms loss, 2.8ms postprocess per image
Results saved to runs/detect/train
💡 Learn more at https://docs.ultralytics.com/modes/train

Exploratory data analysis following https://www.kaggle.com/code/faldoae/exploratory-data-analysis-eda-for-image-datasets. On that site the images are organized into different folders than in the YOLO layout, so the code has to be adapted.
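In the YOLO layout each image has a matching .txt label file with one line per object in the form class_id x_center y_center width height, all normalized to [0, 1]. A minimal parsing sketch (the helper function below is hypothetical, for illustration only):

# Sketch: parse one YOLO-format label file (class_id cx cy w h, normalized)
def parse_yolo_label(label_path):
    boxes = []
    with open(label_path) as f:
        for line in f:
            parts = line.split()
            if len(parts) == 5:
                cls = int(parts[0])
                cx, cy, w, h = map(float, parts[1:])
                boxes.append({"class": cls, "cx": cx, "cy": cy, "w": w, "h": h})
    return boxes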

In [5]:
import os
import yaml
from PIL import Image
import pandas as pd

# Read the YAML
with open('/content/datasets/african-wildlife/african-wildlife.yaml') as f:
    data = yaml.safe_load(f)

# Base folders
base_img_dir = '/content/datasets/african-wildlife/images'
base_label_dir = '/content/datasets/african-wildlife/labels'

splits = ['train', 'val', 'test']
records = []

for split in splits:
    img_dir = os.path.join(base_img_dir, split)
    label_dir = os.path.join(base_label_dir, split)

    if not os.path.exists(img_dir):
        continue

    for img_file in os.listdir(img_dir):
        if img_file.lower().endswith(('.jpg', '.png', '.jpeg')):
            img_path = os.path.join(img_dir, img_file)
            img = Image.open(img_path)
            w, h = img.size

            label_path = os.path.join(label_dir, img_file.rsplit('.', 1)[0] + '.txt')
            if os.path.exists(label_path):
                with open(label_path, 'r') as lf:
                    lines = lf.readlines()
                if lines:
                    first_class = int(lines[0].split()[0])
                    label_name = data['names'][first_class]
                else:
                    label_name = 'no_label'
            else:
                label_name = 'no_label'

            records.append({
                'split': split,
                'filepath': img_path,
                'width': w,
                'height': h,
                'label': label_name
            })

df = pd.DataFrame(records)
print(df.head(), "\nTotal imágenes:", len(df))
   split                                           filepath  width  height  \
0  train  /content/datasets/african-wildlife/images/trai...    500     332   
1  train  /content/datasets/african-wildlife/images/trai...    900     600   
2  train  /content/datasets/african-wildlife/images/trai...   1000     750   
3  train  /content/datasets/african-wildlife/images/trai...    768     500   
4  train  /content/datasets/african-wildlife/images/trai...    640     480   

      label  
0     zebra  
1     rhino  
2   buffalo  
3  elephant  
4     zebra   
Total imágenes: 1504

General statistics.

In [6]:
print("Total imágenes:", len(df))
print("\nDistribución por split:")
print(df['split'].value_counts())

print("\nDistribución por clase:")
print(df['label'].value_counts())
Total imágenes: 1504

Distribución por split:
split
train    1052
test      227
val       225
Name: count, dtype: int64

Distribución por clase:
label
zebra       380
buffalo     376
elephant    376
rhino       372
Name: count, dtype: int64

Count per species in each split.

In [7]:
import seaborn as sns
import matplotlib.pyplot as plt

plt.figure(figsize=(10,5))
sns.countplot(data=df, x='label', hue='split', palette='Set2')
plt.title("Número de imágenes por especie y split")
plt.xticks(rotation=45)
plt.show()
[figure: number of images per species and split]

Distribution of image sizes.

In [8]:
plt.figure(figsize=(10,5))
sns.histplot(data=df, x='width', hue='split', bins=30, element='step')
plt.title("Distribución de ancho de imágenes por split")
plt.show()

plt.figure(figsize=(10,5))
sns.histplot(data=df, x='height', hue='split', bins=30, element='step')
plt.title("Distribución de alto de imágenes por split")
plt.show()
[figure: distribution of image widths per split]
[figure: distribution of image heights per split]

Width/height ratio per species.

In [ ]:
df['ratio'] = df['width'] / df['height']

plt.figure(figsize=(10,5))
sns.boxplot(data=df, x='label', y='ratio')
plt.title("Relación ancho/alto por especie")
plt.xticks(rotation=45)
plt.show()
[figure: boxplot of width/height ratio per species]

Image mosaic per species.

In [ ]:
from PIL import Image
import random

def show_grid(df_subset, nrows=3, ncols=3):
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*3))
    for ax, (_, row) in zip(axes.flatten(), df_subset.iterrows()):
        img = Image.open(row['filepath'])
        ax.imshow(img)
        ax.set_title(row['label'])
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# Show 9 random images per class
for label in df['label'].unique():
    subset = df[df['label'] == label]
    if len(subset) >= 9:
        sample = subset.sample(9)
    else:
        sample = subset
    print(f"Clase: {label}")
    show_grid(sample)
Clase: buffalo
[figure: 3×3 grid of sample buffalo images]
Clase: zebra
[figure: 3×3 grid of sample zebra images]
Clase: elephant
[figure: 3×3 grid of sample elephant images]
Clase: rhino
[figure: 3×3 grid of sample rhino images]

Class distribution.

In [ ]:
plt.figure(figsize=(8,8))
df['label'].value_counts().plot.pie(autopct='%1.1f%%', startangle=90, cmap='tab20')
plt.title("Proporción de clases en todo el dataset")
plt.ylabel("")
plt.show()
[figure: pie chart of class proportions across the dataset]

Copied from the EDA page.

In [ ]:
# Steps to convert an image dataset to a dataframe
# pathlib can be used to search for files
# Here's an example of how to recursively search a directory for files
# and then reformat the result into a dataframe

import pandas as pd
from pathlib import Path

data = '/content/datasets/african-wildlife/images/test'

paths = [path.parts[-2:] for path in
         Path(data).rglob('*.*')]                             # '*.*' so that all image formats are picked up
df = pd.DataFrame(data=paths, columns=['Class','Images'])     # create column names for the dataframe
df = df.sort_values('Class', ascending=True)                  # sort by class name
df.reset_index(drop=True, inplace=True)                       # reset the row index
df                                                            # display the dataframe
Out[ ]:
Class Images
0 test 2 (270).jpg
1 test 1 (30).jpg
2 test 2 (326).jpg
3 test 1 (356).jpg
4 test 3 (328).jpg
... ... ...
222 test 2 (341).jpg
223 test 2 (85).jpg
224 test 3 (123).jpg
225 test 1 (192).jpg
226 test 3 (320).jpg

227 rows × 2 columns

In [ ]:
print('Count the number of image datasets')
print("Image Count : {}".format(len(df.Images)))
print("Class Count : {} \n".format(len(df['Class'].value_counts())))
print('Count the number of images in each class')
print(df['Class'].value_counts())
Count the number of image datasets
Image Count : 227
Class Count : 1 

Count the number of images in each class
Class
test    227
Name: count, dtype: int64

Adapting the RGB histogram code

This histogram uses the df object, the dataframe that holds the image information. However, when running the code I got an error because df had been overwritten in a previous cell, so I had to redefine it with df = pd.DataFrame(records).
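A simple way to avoid this kind of clash (a sketch, reusing the records list built in the earlier EDA cell) is to keep the two tables under different names instead of reusing df:

# Sketch: keep the EDA table and the file listing under separate names
df_images = pd.DataFrame(records)   # split / filepath / width / height / label
df_files = df.copy()                # the Class / Images listing built with pathlib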

In [ ]:
print(df.columns)
print(df.head())
Index(['Class', 'Images'], dtype='object')
  Class       Images
0  test  4 (235).jpg
1  test  4 (252).jpg
2  test  3 (334).jpg
3  test  3 (314).jpg
4  test    3 (8).jpg
In [ ]:
# First attempt with the Class/Images dataframe: note that np, calculate_channel_average
# and channel_avgs are only defined in later cells, so those must be run first
for channel in ["R", "G", "B"]:
    avg_df = df.groupby("Class").agg({
        "Images": lambda paths: np.mean([calculate_channel_average(os.path.join(data, p), channel) for p in paths])
    }).reset_index()
    avg_df.rename(columns={"Images": f"{channel}_mean", "Class":"label"}, inplace=True)
    channel_avgs.append(avg_df)
In [ ]:
records.append({
    'split': split,
    'filepath': img_path,
    'width': w,
    'height': h,
    'label': label_name   # 👈 here
})
df = pd.DataFrame(records)  # rebuild df from records so the label/filepath columns are available again

The following code generates a table with the color composition per class.

In [ ]:
import numpy as np
import cv2

def calculate_channel_average(img_path, channel):
    """
    Compute the average intensity of one channel (R, G, or B).
    """
    channel_dict = {"R": 2, "G": 1, "B": 0}  # OpenCV uses BGR order
    channel_idx = channel_dict[channel]

    img = cv2.imread(img_path)  # loaded as BGR
    if img is None:
        return np.nan
    channel_intensities = img[:,:,channel_idx].flatten()
    return np.mean(channel_intensities)

# Compute channel averages per class
channel_avgs = []
for channel in ["R", "G", "B"]:
    avg_df = df.groupby("label").agg({
        "filepath": lambda paths: np.mean([calculate_channel_average(p, channel) for p in paths])
    }).reset_index()
    avg_df.rename(columns={"filepath": f"{channel}_mean"}, inplace=True)
    channel_avgs.append(avg_df)

# Merge the three dataframes
color_df = channel_avgs[0]
color_df["G_mean"] = channel_avgs[1]["G_mean"]
color_df["B_mean"] = channel_avgs[2]["B_mean"]

print(color_df)
      label      R_mean      G_mean      B_mean
0   buffalo  133.210377  127.855308  102.483883
1  elephant  129.002352  124.825104  106.761089
2     rhino  136.463383  128.335836  106.394742
3     zebra  132.278974  125.813713  100.741905

Average color-channel intensity per species (bar plots).

In [ ]:
def plot_channel_barplot(df, channel):
    title_dict = {"R": "Red", "G": "Green", "B": "Blue"}
    palette_dict = {"R": "Reds", "G": "Greens", "B": "Blues"}

    plt.figure(figsize=(8,4))
    values = df[f"{channel}_mean"].values
    pal = sns.color_palette(palette_dict[channel], len(values))
    rank = values.argsort().argsort()

    ax = sns.barplot(x=df["label"], y=values, palette=np.array(pal[::-1])[rank])
    plt.ylabel("Average Intensity")
    plt.title(f"Average {title_dict[channel]} Channel Intensity per Class")
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
    plt.show()

plot_channel_barplot(color_df, "R")
plot_channel_barplot(color_df, "G")
plot_channel_barplot(color_df, "B")
/tmp/ipython-input-78763786.py:10: FutureWarning: 

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x=df["label"], y=values, palette=np.array(pal[::-1])[rank])
/tmp/ipython-input-78763786.py:10: UserWarning: Numpy array is not a supported type for `palette`. Please convert your palette to a list. This will become an error in v0.14
  ax = sns.barplot(x=df["label"], y=values, palette=np.array(pal[::-1])[rank])
/tmp/ipython-input-78763786.py:13: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
[figure: average red channel intensity per class]
/tmp/ipython-input-78763786.py:10: FutureWarning: 

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x=df["label"], y=values, palette=np.array(pal[::-1])[rank])
/tmp/ipython-input-78763786.py:10: UserWarning: Numpy array is not a supported type for `palette`. Please convert your palette to a list. This will become an error in v0.14
  ax = sns.barplot(x=df["label"], y=values, palette=np.array(pal[::-1])[rank])
/tmp/ipython-input-78763786.py:13: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
[figure: average green channel intensity per class]
/tmp/ipython-input-78763786.py:10: FutureWarning: 

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x=df["label"], y=values, palette=np.array(pal[::-1])[rank])
/tmp/ipython-input-78763786.py:10: UserWarning: Numpy array is not a supported type for `palette`. Please convert your palette to a list. This will become an error in v0.14
  ax = sns.barplot(x=df["label"], y=values, palette=np.array(pal[::-1])[rank])
/tmp/ipython-input-78763786.py:13: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
[figure: average blue channel intensity per class]
In [ ]:
color_df_melt = color_df.melt(id_vars="label", value_vars=["R_mean", "G_mean", "B_mean"],
                              var_name="Channel", value_name="Intensity")

plt.figure(figsize=(8,5))
sns.barplot(data=color_df_melt, x="label", y="Intensity", hue="Channel",
            palette={"R_mean":"red", "G_mean":"green", "B_mean":"blue"})
plt.title("Average RGB Intensities per Class")
plt.xticks(rotation=45)
plt.show()
[figure: grouped bar plot of average RGB intensities per class]

EDA from GitHub

FineTuning RetinaNet for Wildlife Detection with PyTorch

Apparently the downloaded files are part of a complete project environment, so at this point the whole folder has to be uploaded to Colab. Let's see what happens. I wasn't sure whether to upload the zip, so I ended up uploading the files manually.

The first step is to install the dependencies from requirements.txt. I got this warning: WARNING: The following packages were previously imported in this runtime: [aiofiles,cv2,gradio] You must restart the runtime in order to use newly installed versions. To be safe, I chose not to restart.
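As an alternative to uploading files one by one, the project archive can be uploaded and extracted in a couple of lines (a sketch; project.zip and /content/project are hypothetical names):

# Sketch: upload a project archive to Colab and unzip it (archive name and target folder are hypothetical)
from google.colab import files
import zipfile

uploaded = files.upload()                    # pick project.zip from the local machine
with zipfile.ZipFile("project.zip") as zf:
    zf.extractall("/content/project")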

In [10]:
%cd /content/project
!pip install -r requirements.txt
[Errno 2] No such file or directory: '/content/project'
/content
Collecting opencv-python==4.11.0.86 (from -r requirements.txt (line 1))
  Downloading opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting torch==2.6.0 (from -r requirements.txt (line 2))
  Downloading torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision==0.21.0 (from -r requirements.txt (line 3))
  Downloading torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio==2.6.0 (from -r requirements.txt (line 4))
  Downloading torchaudio-2.6.0-cp312-cp312-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting gradio==5.18.0 (from -r requirements.txt (line 5))
  Downloading gradio-5.18.0-py3-none-any.whl.metadata (16 kB)
Requirement already satisfied: numpy>=1.21.2 in /usr/local/lib/python3.12/dist-packages (from opencv-python==4.11.0.86->-r requirements.txt (line 1)) (2.0.2)
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch==2.6.0->-r requirements.txt (line 2)) (3.19.1)
Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch==2.6.0->-r requirements.txt (line 2)) (4.14.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch==2.6.0->-r requirements.txt (line 2)) (3.5)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch==2.6.0->-r requirements.txt (line 2)) (3.1.6)
Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch==2.6.0->-r requirements.txt (line 2)) (2025.3.0)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparselt-cu12==0.6.2 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting nvidia-nccl-cu12==2.21.5 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting nvidia-nvtx-cu12==12.4.127 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting triton==3.2.0 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch==2.6.0->-r requirements.txt (line 2)) (75.2.0)
Collecting sympy==1.13.1 (from torch==2.6.0->-r requirements.txt (line 2))
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.12/dist-packages (from torchvision==0.21.0->-r requirements.txt (line 3)) (11.3.0)
Collecting aiofiles<24.0,>=22.0 (from gradio==5.18.0->-r requirements.txt (line 5))
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (4.10.0)
Requirement already satisfied: fastapi<1.0,>=0.115.2 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.116.1)
Requirement already satisfied: ffmpy in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.6.1)
Collecting gradio-client==1.7.2 (from gradio==5.18.0->-r requirements.txt (line 5))
  Downloading gradio_client-1.7.2-py3-none-any.whl.metadata (7.1 kB)
Requirement already satisfied: httpx>=0.24.1 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.28.1)
Requirement already satisfied: huggingface-hub>=0.28.1 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.34.4)
Collecting markupsafe~=2.0 (from gradio==5.18.0->-r requirements.txt (line 5))
  Downloading MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Requirement already satisfied: orjson~=3.0 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (3.11.2)
Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (25.0)
Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (2.2.2)
Requirement already satisfied: pydantic>=2.0 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (2.11.7)
Requirement already satisfied: pydub in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.25.1)
Requirement already satisfied: python-multipart>=0.0.18 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.0.20)
Requirement already satisfied: pyyaml<7.0,>=5.0 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (6.0.2)
Requirement already satisfied: ruff>=0.9.3 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.12.9)
Requirement already satisfied: safehttpx<0.2.0,>=0.1.6 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.1.6)
Requirement already satisfied: semantic-version~=2.0 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (2.10.0)
Requirement already satisfied: starlette<1.0,>=0.40.0 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.47.2)
Requirement already satisfied: tomlkit<0.14.0,>=0.12.0 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.13.3)
Requirement already satisfied: typer<1.0,>=0.12 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.16.0)
Requirement already satisfied: uvicorn>=0.14.0 in /usr/local/lib/python3.12/dist-packages (from gradio==5.18.0->-r requirements.txt (line 5)) (0.35.0)
Requirement already satisfied: websockets<16.0,>=10.0 in /usr/local/lib/python3.12/dist-packages (from gradio-client==1.7.2->gradio==5.18.0->-r requirements.txt (line 5)) (15.0.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy==1.13.1->torch==2.6.0->-r requirements.txt (line 2)) (1.3.0)
Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.12/dist-packages (from anyio<5.0,>=3.0->gradio==5.18.0->-r requirements.txt (line 5)) (3.10)
Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.12/dist-packages (from anyio<5.0,>=3.0->gradio==5.18.0->-r requirements.txt (line 5)) (1.3.1)
Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from httpx>=0.24.1->gradio==5.18.0->-r requirements.txt (line 5)) (2025.8.3)
Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx>=0.24.1->gradio==5.18.0->-r requirements.txt (line 5)) (1.0.9)
Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx>=0.24.1->gradio==5.18.0->-r requirements.txt (line 5)) (0.16.0)
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.28.1->gradio==5.18.0->-r requirements.txt (line 5)) (2.32.4)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.28.1->gradio==5.18.0->-r requirements.txt (line 5)) (4.67.1)
Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.28.1->gradio==5.18.0->-r requirements.txt (line 5)) (1.1.7)
Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas<3.0,>=1.0->gradio==5.18.0->-r requirements.txt (line 5)) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas<3.0,>=1.0->gradio==5.18.0->-r requirements.txt (line 5)) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas<3.0,>=1.0->gradio==5.18.0->-r requirements.txt (line 5)) (2025.2)
Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic>=2.0->gradio==5.18.0->-r requirements.txt (line 5)) (0.7.0)
Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic>=2.0->gradio==5.18.0->-r requirements.txt (line 5)) (2.33.2)
Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic>=2.0->gradio==5.18.0->-r requirements.txt (line 5)) (0.4.1)
Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.12/dist-packages (from typer<1.0,>=0.12->gradio==5.18.0->-r requirements.txt (line 5)) (8.2.1)
Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from typer<1.0,>=0.12->gradio==5.18.0->-r requirements.txt (line 5)) (1.5.4)
Requirement already satisfied: rich>=10.11.0 in /usr/local/lib/python3.12/dist-packages (from typer<1.0,>=0.12->gradio==5.18.0->-r requirements.txt (line 5)) (13.9.4)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas<3.0,>=1.0->gradio==5.18.0->-r requirements.txt (line 5)) (1.17.0)
Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio==5.18.0->-r requirements.txt (line 5)) (4.0.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio==5.18.0->-r requirements.txt (line 5)) (2.19.2)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.28.1->gradio==5.18.0->-r requirements.txt (line 5)) (3.4.3)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.28.1->gradio==5.18.0->-r requirements.txt (line 5)) (2.5.0)
Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio==5.18.0->-r requirements.txt (line 5)) (0.1.2)
Downloading opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (63.0 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.0/63.0 MB 11.7 MB/s eta 0:00:00
Downloading torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl (766.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 766.6/766.6 MB 1.9 MB/s eta 0:00:00
Downloading torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl (7.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.2/7.2 MB 130.6 MB/s eta 0:00:00
Downloading torchaudio-2.6.0-cp312-cp312-manylinux1_x86_64.whl (3.4 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.4/3.4 MB 98.0 MB/s eta 0:00:00
Downloading gradio-5.18.0-py3-none-any.whl (62.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.3/62.3 MB 8.4 MB/s eta 0:00:00
Downloading gradio_client-1.7.2-py3-none-any.whl (322 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 322.1/322.1 kB 31.4 MB/s eta 0:00:00
Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 1.5 MB/s eta 0:00:00
Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 93.5 MB/s eta 0:00:00
Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 84.2 MB/s eta 0:00:00
Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 883.7/883.7 kB 63.6 MB/s eta 0:00:00
Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 2.9 MB/s eta 0:00:00
Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.5/211.5 MB 5.4 MB/s eta 0:00:00
Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 15.1 MB/s eta 0:00:00
Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 127.9/127.9 MB 7.8 MB/s eta 0:00:00
Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.5/207.5 MB 5.7 MB/s eta 0:00:00
Downloading nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl (150.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 150.1/150.1 MB 6.7 MB/s eta 0:00:00
Downloading nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl (188.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 188.7/188.7 MB 6.1 MB/s eta 0:00:00
Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.1/21.1 MB 108.5 MB/s eta 0:00:00
Downloading nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (99 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 kB 10.0 MB/s eta 0:00:00
Downloading sympy-1.13.1-py3-none-any.whl (6.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.2/6.2 MB 130.1 MB/s eta 0:00:00
Downloading triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 253.2/253.2 MB 3.1 MB/s eta 0:00:00
Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)
Downloading MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (28 kB)
Installing collected packages: triton, nvidia-cusparselt-cu12, sympy, opencv-python, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, markupsafe, aiofiles, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, gradio-client, torch, gradio, torchvision, torchaudio
  Attempting uninstall: triton
    Found existing installation: triton 3.4.0
    Uninstalling triton-3.4.0:
      Successfully uninstalled triton-3.4.0
  Attempting uninstall: nvidia-cusparselt-cu12
    Found existing installation: nvidia-cusparselt-cu12 0.7.1
    Uninstalling nvidia-cusparselt-cu12-0.7.1:
      Successfully uninstalled nvidia-cusparselt-cu12-0.7.1
  Attempting uninstall: sympy
    Found existing installation: sympy 1.13.3
    Uninstalling sympy-1.13.3:
      Successfully uninstalled sympy-1.13.3
  Attempting uninstall: opencv-python
    Found existing installation: opencv-python 4.12.0.88
    Uninstalling opencv-python-4.12.0.88:
      Successfully uninstalled opencv-python-4.12.0.88
  Attempting uninstall: nvidia-nvtx-cu12
    Found existing installation: nvidia-nvtx-cu12 12.6.77
    Uninstalling nvidia-nvtx-cu12-12.6.77:
      Successfully uninstalled nvidia-nvtx-cu12-12.6.77
  Attempting uninstall: nvidia-nvjitlink-cu12
    Found existing installation: nvidia-nvjitlink-cu12 12.6.85
    Uninstalling nvidia-nvjitlink-cu12-12.6.85:
      Successfully uninstalled nvidia-nvjitlink-cu12-12.6.85
  Attempting uninstall: nvidia-nccl-cu12
    Found existing installation: nvidia-nccl-cu12 2.27.3
    Uninstalling nvidia-nccl-cu12-2.27.3:
      Successfully uninstalled nvidia-nccl-cu12-2.27.3
  Attempting uninstall: nvidia-curand-cu12
    Found existing installation: nvidia-curand-cu12 10.3.7.77
    Uninstalling nvidia-curand-cu12-10.3.7.77:
      Successfully uninstalled nvidia-curand-cu12-10.3.7.77
  Attempting uninstall: nvidia-cufft-cu12
    Found existing installation: nvidia-cufft-cu12 11.3.0.4
    Uninstalling nvidia-cufft-cu12-11.3.0.4:
      Successfully uninstalled nvidia-cufft-cu12-11.3.0.4
  Attempting uninstall: nvidia-cuda-runtime-cu12
    Found existing installation: nvidia-cuda-runtime-cu12 12.6.77
    Uninstalling nvidia-cuda-runtime-cu12-12.6.77:
      Successfully uninstalled nvidia-cuda-runtime-cu12-12.6.77
  Attempting uninstall: nvidia-cuda-nvrtc-cu12
    Found existing installation: nvidia-cuda-nvrtc-cu12 12.6.77
    Uninstalling nvidia-cuda-nvrtc-cu12-12.6.77:
      Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.6.77
  Attempting uninstall: nvidia-cuda-cupti-cu12
    Found existing installation: nvidia-cuda-cupti-cu12 12.6.80
    Uninstalling nvidia-cuda-cupti-cu12-12.6.80:
      Successfully uninstalled nvidia-cuda-cupti-cu12-12.6.80
  Attempting uninstall: nvidia-cublas-cu12
    Found existing installation: nvidia-cublas-cu12 12.6.4.1
    Uninstalling nvidia-cublas-cu12-12.6.4.1:
      Successfully uninstalled nvidia-cublas-cu12-12.6.4.1
  Attempting uninstall: markupsafe
    Found existing installation: MarkupSafe 3.0.2
    Uninstalling MarkupSafe-3.0.2:
      Successfully uninstalled MarkupSafe-3.0.2
  Attempting uninstall: aiofiles
    Found existing installation: aiofiles 24.1.0
    Uninstalling aiofiles-24.1.0:
      Successfully uninstalled aiofiles-24.1.0
  Attempting uninstall: nvidia-cusparse-cu12
    Found existing installation: nvidia-cusparse-cu12 12.5.4.2
    Uninstalling nvidia-cusparse-cu12-12.5.4.2:
      Successfully uninstalled nvidia-cusparse-cu12-12.5.4.2
  Attempting uninstall: nvidia-cudnn-cu12
    Found existing installation: nvidia-cudnn-cu12 9.10.2.21
    Uninstalling nvidia-cudnn-cu12-9.10.2.21:
      Successfully uninstalled nvidia-cudnn-cu12-9.10.2.21
  Attempting uninstall: nvidia-cusolver-cu12
    Found existing installation: nvidia-cusolver-cu12 11.7.1.2
    Uninstalling nvidia-cusolver-cu12-11.7.1.2:
      Successfully uninstalled nvidia-cusolver-cu12-11.7.1.2
  Attempting uninstall: gradio-client
    Found existing installation: gradio_client 1.11.1
    Uninstalling gradio_client-1.11.1:
      Successfully uninstalled gradio_client-1.11.1
  Attempting uninstall: torch
    Found existing installation: torch 2.8.0+cu126
    Uninstalling torch-2.8.0+cu126:
      Successfully uninstalled torch-2.8.0+cu126
  Attempting uninstall: gradio
    Found existing installation: gradio 5.42.0
    Uninstalling gradio-5.42.0:
      Successfully uninstalled gradio-5.42.0
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.23.0+cu126
    Uninstalling torchvision-0.23.0+cu126:
      Successfully uninstalled torchvision-0.23.0+cu126
  Attempting uninstall: torchaudio
    Found existing installation: torchaudio 2.8.0+cu126
    Uninstalling torchaudio-2.8.0+cu126:
      Successfully uninstalled torchaudio-2.8.0+cu126
Successfully installed aiofiles-23.2.1 gradio-5.18.0 gradio-client-1.7.2 markupsafe-2.1.5 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-cusparselt-cu12-0.6.2 nvidia-nccl-cu12-2.21.5 nvidia-nvjitlink-cu12-12.4.127 nvidia-nvtx-cu12-12.4.127 opencv-python-4.11.0.86 sympy-1.13.1 torch-2.6.0 torchaudio-2.6.0 torchvision-0.21.0 triton-3.2.0
In [11]:
from model import create_model
from config import NUM_CLASSES, DEVICE, CLASSES
from datasets import CustomDataset
from custom_utils import collate_fn

The .py files uploaded to the project are executed with %cd /content/ followed by !python custom_utils.py. I had to adjust the first file, config.py, so that it points to the directories where I saved the data. When running train.py I hit so many errors that I ended up changing config.py many times; the paths for the train and val data also had to be changed.
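
Before editing config.py, a quick way to check the actual layout on disk is to walk the dataset root (a small sketch using only the standard library; the root path matches the one used later in this notebook):

import os

root = "/content/datasets/african-wildlife"
for dirpath, dirnames, filenames in os.walk(root):
    depth = dirpath[len(root):].count(os.sep)
    if depth <= 2:  # only print the top levels
        print(dirpath, f"- {len(filenames)} files")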

In [12]:
%%writefile /content/config.py
# config.py
import torch

BATCH_SIZE = 8  # Increase / decrease according to GPU memory.
RESIZE_TO = 640  # Resize the image for training and transforms.
NUM_EPOCHS = 60  # Number of epochs to train for.
NUM_WORKERS = 4  # Number of parallel workers for data loading.

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training images and labels files directory.
TRAIN_DIR = "/content/datasets/african-wildlife/train"
# Validation images and labels files directory.
VALID_DIR = "/content/datasets/african-wildlife/val"

# Classes: 0 index is reserved for background.
CLASSES = ["__background__", "buffalo", "elephant", "rhino", "zebra"]


NUM_CLASSES = len(CLASSES)

# Whether to visualize images after creating the data loaders.
VISUALIZE_TRANSFORMED_IMAGES = True

# Location to save model and plots.
OUT_DIR = "outputs"
Overwriting /content/config.py

The %cd /content/ plus !python custom_utils.py command runs the .py files. I later realized that prepending %cd /content/ was no longer necessary, and in the end this script did not need to be run at all.

In [ ]:
%cd /content/
!python custom_utils.py
/content

Apparently not every one of the .py files is meant to be executed directly: some only contain code that is imported by the .py files that are actually run (those load in the functionality of the rest), although the tutorial does not say which files are the entry points.
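
The pattern behind this is Python's __main__ guard: a module only runs its demo/entry code when executed directly, not when imported, which is why some of these files act as libraries for the others (datasets.py below uses exactly this guard). A tiny illustration with a hypothetical helper.py:

# helper.py (hypothetical)
def build():
    return "something useful"

if __name__ == "__main__":
    # Runs only with `python helper.py`, never on `import helper`.
    print(build())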

In [ ]:
%cd /content/
!python model.py
/content
Downloading: "https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth" to /root/.cache/torch/hub/checkpoints/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth
100% 146M/146M [00:00<00:00, 185MB/s]
RetinaNet(
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2dNormActivation(
          (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (layer_blocks): ModuleList(
        (0-2): 3 x Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (extra_blocks): LastLevelP6P7(
        (p6): Conv2d(2048, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (p7): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
    )
  )
  (anchor_generator): AnchorGenerator()
  (head): RetinaNetHead(
    (classification_head): RetinaNetClassificationHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
      )
      (cls_logits): Conv2d(256, 45, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (regression_head): RetinaNetRegressionHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
      )
      (bbox_reg): Conv2d(256, 36, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
)
36,414,865 total parameters.
36,189,521 training parameters.

At this point I need to install torchmetrics so that no error comes up.

In [13]:
!pip install torchmetrics
Collecting torchmetrics
  Downloading torchmetrics-1.8.1-py3-none-any.whl.metadata (22 kB)
Requirement already satisfied: numpy>1.20.0 in /usr/local/lib/python3.12/dist-packages (from torchmetrics) (2.0.2)
Requirement already satisfied: packaging>17.1 in /usr/local/lib/python3.12/dist-packages (from torchmetrics) (25.0)
Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from torchmetrics) (2.6.0)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (75.2.0)
Requirement already satisfied: typing_extensions in /usr/local/lib/python3.12/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (4.14.1)
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (3.19.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (3.5)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (3.1.6)
Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (2025.3.0)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.4.127)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.4.127)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.4.127)
Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (9.1.0.70)
Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.4.5.8)
Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (11.2.1.3)
Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (10.3.5.147)
Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (11.6.1.9)
Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.3.1.170)
Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (0.6.2)
Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (2.21.5)
Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.4.127)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (12.4.127)
Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (3.2.0)
Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->torchmetrics) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy==1.13.1->torch>=2.0.0->torchmetrics) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=2.0.0->torchmetrics) (2.1.5)
Downloading torchmetrics-1.8.1-py3-none-any.whl (982 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 983.0/983.0 kB 25.2 MB/s eta 0:00:00
Downloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.1
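
torchmetrics provides the detection mAP metric that the training/evaluation code presumably imports; for reference, a minimal sketch of its MeanAveragePrecision API with made-up boxes:

import torch
from torchmetrics.detection import MeanAveragePrecision

metric = MeanAveragePrecision()
# One image, one dummy prediction and one dummy ground-truth box.
preds = [{
    "boxes": torch.tensor([[10.0, 20.0, 100.0, 120.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 22.0, 98.0, 118.0]]),
    "labels": torch.tensor([1]),
}]
metric.update(preds, targets)
print(metric.compute()["map"])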

train.py cannot find the images because they sit in a different folder layout. datasets.py has to be modified so that it looks for them in the right place.

In [14]:
%%writefile /content/datasets.py
# datasets.py
import torch
import cv2
import numpy as np
import os
import glob

from config import CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE
from torch.utils.data import Dataset, DataLoader
from custom_utils import collate_fn, get_train_transform, get_valid_transform


class CustomDataset(Dataset):
    def __init__(self, dir_path, width, height, classes, transforms=None):
        """
        :param dir_path: Directory containing 'images/' and 'labels/' subfolders.
        :param width: Resized image width.
        :param height: Resized image height.
        :param classes: List of class names (or an indexing scheme).
        :param transforms: Albumentations transformations to apply.
        """
        self.transforms = transforms
        self.dir_path = dir_path
        self.image_dir = os.path.join(self.dir_path, "images")
        self.label_dir = os.path.join(self.dir_path, "labels")
        self.width = width
        self.height = height
        self.classes = classes

        # Gather all image paths
        self.image_file_types = ["*.jpg", "*.jpeg", "*.png", "*.ppm", "*.JPG"]
        self.all_image_paths = []
        for file_type in self.image_file_types:
            self.all_image_paths.extend(glob.glob(os.path.join(self.image_dir, file_type)))

        # Sort for consistent ordering
        self.all_image_paths = sorted(self.all_image_paths)
        self.all_image_names = [os.path.basename(img_p) for img_p in self.all_image_paths]

    def __len__(self):
        return len(self.all_image_paths)

    def __getitem__(self, idx):
        # 1) Read image
        image_name = self.all_image_names[idx]
        image_path = os.path.join(self.image_dir, image_name)
        label_filename = os.path.splitext(image_name)[0] + ".txt"
        label_path = os.path.join(self.label_dir, label_filename)

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

        # 2) Resize image (to the model's expected size)
        image_resized = cv2.resize(image, (self.width, self.height))
        image_resized /= 255.0  # Scale pixel values to [0, 1]

        # 3) Read bounding boxes (normalized) from .txt file
        boxes = []
        labels = []
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                lines = f.readlines()

            for line in lines:
                line = line.strip()
                if not line:
                    continue
                # Format: class_id x_min y_min x_max y_max  (all in [0..1])
                parts = line.split()
                class_id = int(parts[0])  # e.g. 0, 1, 2, ...
                xmin = float(parts[1])
                ymin = float(parts[2])
                xmax = float(parts[3])
                ymax = float(parts[4])

                # Example: if you want class IDs to start at 1 for foreground
                # and background=0, do:
                label_idx = class_id + 1

                # Convert normalized coords to absolute (in resized space)
                x_min_final = xmin * self.width
                y_min_final = ymin * self.height
                x_max_final = xmax * self.width
                y_max_final = ymax * self.height

                # Ensure valid box
                if x_max_final <= x_min_final:
                    x_max_final = x_min_final + 1
                if y_max_final <= y_min_final:
                    y_max_final = y_min_final + 1

                # Clip if out of bounds
                x_min_final = max(0, min(x_min_final, self.width - 1))
                x_max_final = max(0, min(x_max_final, self.width))
                y_min_final = max(0, min(y_min_final, self.height - 1))
                y_max_final = max(0, min(y_max_final, self.height))

                boxes.append([x_min_final, y_min_final, x_max_final, y_max_final])
                labels.append(label_idx)

        # 4) Convert boxes & labels to Torch tensors
        if len(boxes) == 0:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.tensor(labels, dtype=torch.int64)

        # 5) Prepare the target dict
        area = (
            (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            if len(boxes) > 0
            else torch.tensor([], dtype=torch.float32)
        )
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        image_id = torch.tensor([idx])

        target = {"boxes": boxes, "labels": labels, "area": area, "iscrowd": iscrowd, "image_id": image_id}

        # 6) Albumentations transforms: pass Python lists, not Tensors
        if self.transforms:
            bboxes_list = boxes.cpu().numpy().tolist()  # shape: list of [xmin, ymin, xmax, ymax]
            labels_list = labels.cpu().numpy().tolist()  # shape: list of ints

            transformed = self.transforms(
                image=image_resized,
                bboxes=bboxes_list,
                labels=labels_list,
            )

            # Reassign the image
            image_resized = transformed["image"]

            # Convert bboxes back to Torch Tensors
            new_bboxes_list = transformed["bboxes"]  # list of [xmin, ymin, xmax, ymax]
            new_labels_list = transformed["labels"]  # list of int

            if len(new_bboxes_list) > 0:
                new_bboxes = torch.tensor(new_bboxes_list, dtype=torch.float32)
                new_labels = torch.tensor(new_labels_list, dtype=torch.int64)
            else:
                new_bboxes = torch.zeros((0, 4), dtype=torch.float32)
                new_labels = torch.zeros((0,), dtype=torch.int64)

            target["boxes"] = new_bboxes
            target["labels"] = new_labels

        return image_resized, target


# ---------------------------------------------------------
# Create train/valid datasets and loaders
# ---------------------------------------------------------
def create_train_dataset(DIR):
    train_dataset = CustomDataset(
        dir_path=DIR, width=RESIZE_TO, height=RESIZE_TO, classes=CLASSES, transforms=get_train_transform()
    )
    return train_dataset


def create_valid_dataset(DIR):
    valid_dataset = CustomDataset(
        dir_path=DIR, width=RESIZE_TO, height=RESIZE_TO, classes=CLASSES, transforms=get_valid_transform()
    )
    return valid_dataset


def create_train_loader(train_dataset, num_workers=0):
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True,
    )
    return train_loader


def create_valid_loader(valid_dataset, num_workers=0):
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True,
    )
    return valid_loader


# ---------------------------------------------------------
# Debug/demo if run directly
# ---------------------------------------------------------
if __name__ == "__main__":
    # Example usage with no transforms for debugging
    dataset = CustomDataset(dir_path=TRAIN_DIR, width=RESIZE_TO, height=RESIZE_TO, classes=CLASSES, transforms=None)
    print(f"Number of training images: {len(dataset)}")

    def visualize_sample(image, target):
        """
        Visualize a single sample using OpenCV. Expects
        `image` as a NumPy array of shape (H, W, 3) in [0..1].
        """
        # Convert [0,1] float -> [0,255] uint8
        img = (image * 255).astype(np.uint8)
        # Convert RGB -> BGR
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        boxes = target["boxes"].cpu().numpy().astype(np.int32)
        labels = target["labels"].cpu().numpy().astype(np.int32)

        for i, box in enumerate(boxes):
            x1, y1, x2, y2 = box
            class_idx = labels[i]

            # If your class_idx starts at 1 for "first class", ensure you handle that:
            # e.g. if CLASSES = ["background", "class1", "class2", ...]
            if 0 <= class_idx < len(CLASSES):
                class_str = CLASSES[class_idx]
            else:
                class_str = f"Label_{class_idx}"

            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
            cv2.putText(img, class_str, (x1, max(y1 - 5, 0)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

        cv2.imshow("Sample", img)
        cv2.waitKey(0)

    # Visualize a few samples
    NUM_SAMPLES_TO_VISUALIZE = 10
    for i in range(NUM_SAMPLES_TO_VISUALIZE):
        image, target = dataset[i]  # No transforms in this example
        # `image` is shape (H, W, 3) in [0..1]
        print(f"Visualizing sample {i}, boxes: {target['boxes'].shape[0]}")
        visualize_sample(image, target)
    cv2.destroyAllWindows()
Overwriting /content/datasets.py

After trying so hard to adjust the paths, I gave up and will instead rearrange everything into the folder layout the code expects.

In [15]:
# Create the target folders
!mkdir -p /content/datasets/african-wildlife/train/images
!mkdir -p /content/datasets/african-wildlife/train/labels
!mkdir -p /content/datasets/african-wildlife/val/images
!mkdir -p /content/datasets/african-wildlife/val/labels
!mkdir -p /content/datasets/african-wildlife/test/images
!mkdir -p /content/datasets/african-wildlife/test/labels

# Copy the images into the new layout
!cp /content/datasets/african-wildlife/images/train/* /content/datasets/african-wildlife/train/images/
!cp /content/datasets/african-wildlife/images/val/* /content/datasets/african-wildlife/val/images/
!cp /content/datasets/african-wildlife/images/test/* /content/datasets/african-wildlife/test/images/

# Copy the labels into the new layout
!cp /content/datasets/african-wildlife/labels/train/* /content/datasets/african-wildlife/train/labels/
!cp /content/datasets/african-wildlife/labels/val/* /content/datasets/african-wildlife/val/labels/
!cp /content/datasets/african-wildlife/labels/test/* /content/datasets/african-wildlife/test/labels/
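
With the data rearranged, a quick sanity check like this one can verify that datasets.py now finds the images (just a sketch; it assumes custom_utils.collate_fn pairs up images and targets the way the standard torchvision detection recipe does):

from config import TRAIN_DIR
from datasets import create_train_dataset, create_train_loader

train_ds = create_train_dataset(TRAIN_DIR)
print(f"Training images found: {len(train_ds)}")

train_loader = create_train_loader(train_ds, num_workers=0)
images, targets = next(iter(train_loader))  # assumes collate_fn returns (images, targets)
print(targets[0]["boxes"].shape)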

What I did after a trillion attempts: move the images and labels into the new folder layout with the code above, and then change the path where config.py looks for the data. With that adjustment in place:

In [16]:
%cd /content/
!python train.py
/content
/usr/local/lib/python3.12/dist-packages/albumentations/core/composition.py:331: UserWarning: Got processor for bboxes, but no transform to process it.
  self._set_keys()
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:624: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Number of training samples: 1052
Number of validation samples: 225

Downloading: "https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth" to /root/.cache/torch/hub/checkpoints/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth
100% 146M/146M [00:00<00:00, 182MB/s]
RetinaNet(
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2dNormActivation(
          (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (layer_blocks): ModuleList(
        (0-2): 3 x Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (extra_blocks): LastLevelP6P7(
        (p6): Conv2d(2048, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (p7): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
    )
  )
  (anchor_generator): AnchorGenerator()
  (head): RetinaNetHead(
    (classification_head): RetinaNetClassificationHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
      )
      (cls_logits): Conv2d(256, 45, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (regression_head): RetinaNetRegressionHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
      )
      (bbox_reg): Conv2d(256, 36, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
)
36,414,865 total parameters.
36,189,521 training parameters.
qt.qpa.xcb: could not connect to display 
qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "/usr/local/lib/python3.12/dist-packages/cv2/qt/plugins" even though it was found.
This application failed to start because no Qt platform plugin could be initialized. Reinstalling the application may fix this problem.

Available platform plugins are: xcb.
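The two totals near the top of this output ("36,414,865 total parameters" / "36,189,521 training parameters") are presumably the usual numel sums over the model's parameters. A minimal sketch of how such counts are typically produced (the helper name is illustrative, not from the project's code):

# Sketch: counting total vs. trainable parameters of a torch.nn.Module
# (here, the RetinaNet returned by create_model).
def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

# total, trainable = count_parameters(model)
# print(f"{total:,} total parameters.")
# print(f"{trainable:,} training parameters.")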

The qt.qpa.xcb error in the output above comes from OpenCV trying to open a display window on the headless Colab VM. Now I try running app.py instead.

In [17]:
!pip install gradio
Requirement already satisfied: gradio in /usr/local/lib/python3.12/dist-packages (5.18.0)
Requirement already satisfied: aiofiles<24.0,>=22.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (23.2.1)
Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (4.10.0)
Requirement already satisfied: fastapi<1.0,>=0.115.2 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.116.1)
Requirement already satisfied: ffmpy in /usr/local/lib/python3.12/dist-packages (from gradio) (0.6.1)
Requirement already satisfied: gradio-client==1.7.2 in /usr/local/lib/python3.12/dist-packages (from gradio) (1.7.2)
Requirement already satisfied: httpx>=0.24.1 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.28.1)
Requirement already satisfied: huggingface-hub>=0.28.1 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.34.4)
Requirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (3.1.6)
Requirement already satisfied: markupsafe~=2.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (2.1.5)
Requirement already satisfied: numpy<3.0,>=1.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (2.0.2)
Requirement already satisfied: orjson~=3.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (3.11.2)
Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from gradio) (25.0)
Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (2.2.2)
Requirement already satisfied: pillow<12.0,>=8.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (11.3.0)
Requirement already satisfied: pydantic>=2.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (2.11.7)
Requirement already satisfied: pydub in /usr/local/lib/python3.12/dist-packages (from gradio) (0.25.1)
Requirement already satisfied: python-multipart>=0.0.18 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.0.20)
Requirement already satisfied: pyyaml<7.0,>=5.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (6.0.2)
Requirement already satisfied: ruff>=0.9.3 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.12.9)
Requirement already satisfied: safehttpx<0.2.0,>=0.1.6 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.1.6)
Requirement already satisfied: semantic-version~=2.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (2.10.0)
Requirement already satisfied: starlette<1.0,>=0.40.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.47.2)
Requirement already satisfied: tomlkit<0.14.0,>=0.12.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.13.3)
Requirement already satisfied: typer<1.0,>=0.12 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.16.0)
Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (4.14.1)
Requirement already satisfied: uvicorn>=0.14.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.35.0)
Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from gradio-client==1.7.2->gradio) (2025.3.0)
Requirement already satisfied: websockets<16.0,>=10.0 in /usr/local/lib/python3.12/dist-packages (from gradio-client==1.7.2->gradio) (15.0.1)
Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.12/dist-packages (from anyio<5.0,>=3.0->gradio) (3.10)
Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.12/dist-packages (from anyio<5.0,>=3.0->gradio) (1.3.1)
Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from httpx>=0.24.1->gradio) (2025.8.3)
Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx>=0.24.1->gradio) (1.0.9)
Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx>=0.24.1->gradio) (0.16.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.28.1->gradio) (3.19.1)
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.28.1->gradio) (2.32.4)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.28.1->gradio) (4.67.1)
Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.28.1->gradio) (1.1.7)
Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas<3.0,>=1.0->gradio) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas<3.0,>=1.0->gradio) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas<3.0,>=1.0->gradio) (2025.2)
Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic>=2.0->gradio) (0.7.0)
Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic>=2.0->gradio) (2.33.2)
Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic>=2.0->gradio) (0.4.1)
Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.12/dist-packages (from typer<1.0,>=0.12->gradio) (8.2.1)
Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from typer<1.0,>=0.12->gradio) (1.5.4)
Requirement already satisfied: rich>=10.11.0 in /usr/local/lib/python3.12/dist-packages (from typer<1.0,>=0.12->gradio) (13.9.4)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas<3.0,>=1.0->gradio) (1.17.0)
Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (4.0.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (2.19.2)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.28.1->gradio) (3.4.3)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.28.1->gradio) (2.5.0)
Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio) (0.1.2)
In [18]:
!mkdir -p outputs
!cp best_model_79.pth outputs/

app.py does not work as-is, so I change the code so the interface can be opened from Colab. There are also booleans that leave it stuck in a loop.

In [19]:
%%writefile /content/app.py
# app.py
import os
import cv2
import time
import torch
import gradio as gr
import numpy as np

# Make sure these are your local imports from your project.
from model import create_model
from config import NUM_CLASSES, DEVICE, CLASSES

# ----------------------------------------------------------------
# GLOBAL SETUP
# ----------------------------------------------------------------
# Create the model and load the best weights.
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model_79.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()

# Create a colors array for each class index.
# (length matches len(CLASSES), including background if you wish).
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

# COLORS = [
#     (255, 255, 0),  # Cyan - background
#     (50, 0, 255),  # Red - buffalo
#     (147, 20, 255),  # Pink - elephant
#     (0, 255, 0),  # Green - rhino
#     (238, 130, 238),  # Violet - zebra
# ]


# ----------------------------------------------------------------
# HELPER FUNCTIONS
# ----------------------------------------------------------------
def inference_on_image(orig_image: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Runs inference on a single image (OpenCV BGR or NumPy array).
    - resize_dim: if not None, we resize to (resize_dim, resize_dim)
    - threshold: detection confidence threshold
    Returns: processed image with bounding boxes drawn.
    """
    image = orig_image.copy()
    # Optionally resize for inference.
    if resize_dim is not None:
        image = cv2.resize(image, (resize_dim, resize_dim))

    # Convert BGR to RGB, normalize [0..1]
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Move channels to front (C,H,W)
    image_tensor = torch.tensor(image_rgb.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0).to(DEVICE)
    start_time = time.time()
    # Inference
    with torch.no_grad():
        outputs = model(image_tensor)
    end_time = time.time()
    # Get the current fps.
    fps = 1 / (end_time - start_time)
    fps_text = f"FPS: {fps:.2f}"
    # Move outputs to CPU numpy
    outputs = [{k: v.cpu() for k, v in t.items()} for t in outputs]
    boxes = outputs[0]["boxes"].numpy()
    scores = outputs[0]["scores"].numpy()
    labels = outputs[0]["labels"].numpy().astype(int)

    # Filter out boxes with low confidence
    valid_idx = np.where(scores >= threshold)[0]
    boxes = boxes[valid_idx].astype(int)
    labels = labels[valid_idx]

    # Original image size (needed to rescale boxes and to place the FPS text).
    h_orig, w_orig = orig_image.shape[:2]

    # If we resized for inference, rescale boxes back to orig_image size
    if resize_dim is not None:
        h_new, w_new = resize_dim, resize_dim
        # scale boxes
        boxes[:, [0, 2]] = (boxes[:, [0, 2]] / w_new) * w_orig
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] / h_new) * h_orig

    # Draw bounding boxes
    for box, label_idx in zip(boxes, labels):
        class_name = CLASSES[label_idx] if 0 <= label_idx < len(CLASSES) else str(label_idx)
        color = COLORS[label_idx % len(COLORS)][::-1]  # BGR color
        cv2.rectangle(orig_image, (box[0], box[1]), (box[2], box[3]), color, 5)
        cv2.putText(orig_image, class_name, (box[0], box[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3)

    # Draw the FPS overlay once (not once per box).
    cv2.putText(
        orig_image,
        fps_text,
        (int((w_orig / 2) - 50), 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (0, 255, 0),
        2,
        cv2.LINE_AA,
    )
    return orig_image, fps


def inference_on_frame(frame: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Same as inference_on_image but for a single video frame.
    Returns the processed frame with bounding boxes.
    """
    return inference_on_image(frame, resize_dim, threshold)


# ----------------------------------------------------------------
# GRADIO FUNCTIONS
# ----------------------------------------------------------------


def img_inf(image_path, resize_dim, threshold):
    """
    Gradio function for image inference.
    :param image_path: File path from Gradio (uploaded image).
    :param resize_dim: Inference resize dimension from the slider.
    :param threshold: Confidence threshold from the slider.
    Returns: A NumPy image array with bounding boxes.
    """
    if image_path is None:
        return None  # No image provided
    orig_image = cv2.imread(image_path)  # BGR
    if orig_image is None:
        return None  # Error reading image

    result_image, _ = inference_on_image(orig_image, resize_dim=resize_dim, threshold=threshold)
    # Return the image in RGB for Gradio's display
    result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    return result_image_rgb


def vid_inf(video_path, resize_dim, threshold):
    """
    Gradio function for video inference.
    Processes each frame, draws bounding boxes, and writes to an output video.
    Returns: (last_processed_frame, output_video_file_path)
    """
    if video_path is None:
        return None, None  # No video provided

    # Prepare input capture
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None

    # Create an output file path
    os.makedirs("inference_outputs/videos", exist_ok=True)
    out_video_path = os.path.join("inference_outputs/videos", "video_output.mp4")
    # out_video_path = "video_output.mp4"

    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # or 'XVID'

    # If FPS is 0 (some weird container), default to something
    if fps <= 0:
        fps = 20.0

    out_writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))

    frame_count = 0
    total_fps = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Inference on frame
        processed_frame, frame_fps = inference_on_frame(frame, resize_dim=resize_dim, threshold=threshold)
        total_fps += frame_fps
        frame_count += 1

        # Write the processed frame
        out_writer.write(processed_frame)
        yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB), None

    avg_fps = total_fps / frame_count if frame_count else 0.0

    cap.release()
    out_writer.release()
    print(f"Average FPS: {avg_fps:.3f}")
    yield None, out_video_path


# ----------------------------------------------------------------
# BUILD THE GRADIO INTERFACES
# ----------------------------------------------------------------

# Inputs for the image tab: resize dimension and confidence threshold sliders.
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
inputs_image = gr.Image(type="filepath", label="Input Image")
outputs_image = gr.Image(type="numpy", label="Output Image")

interface_image = gr.Interface(
    fn=img_inf,
    inputs=[inputs_image, resize_dim, threshold],
    outputs=outputs_image,
    title="Image Inference",
    description="Upload your photo, select a model, and see the results!",
    examples=[["examples/buffalo.jpg"], ["examples/zebra.jpg"]],
    cache_examples=False,
)

resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
input_video = gr.Video(label="Input Video")

# Output is a pair: (last_processed_frame, output_video_path)
output_frame = gr.Image(type="numpy", label="Output (Last Processed Frame)")
output_video_file = gr.Video(format="mp4", label="Output Video")

interface_video = gr.Interface(
    fn=vid_inf,
    inputs=[input_video, resize_dim, threshold],
    outputs=[output_frame, output_video_file],
    title="Video Inference",
    description="Upload your video and see the processed output!",
    examples=[["examples/elephants.mp4"], ["examples/rhino.mp4"]],
    cache_examples=False,
)

# Combine them in a Tabbed Interface and launch the Gradio app
# ----------------------------------------------------------------
demo = gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=["Image", "Video"],
    title="FineTuning RetinaNet for Wildlife Animal Detection",
    theme="gstaff/xkcd",
)

# Use .queue() to stream video results efficiently, and launch with share=True so Colab exposes a public URL
demo.queue().launch(share=True, inbrowser=True)
Overwriting /content/app.py
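An optional refactor worth noting (a sketch, not part of the original script): because demo.queue().launch(...) runs at module level, importing app.py from another cell would also start the server. Guarding the launch keeps the helper functions importable for standalone testing of inference_on_image:

# Sketch: guard the module-level launch so app.py can also be imported as a module.
if __name__ == "__main__":
    demo.queue().launch(share=True, inbrowser=True)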

There are still errors related to booleans; this was fixed by upgrading Gradio.

In [20]:
pip install --upgrade gradio gradio-client
Requirement already satisfied: gradio in /usr/local/lib/python3.12/dist-packages (5.18.0)
Collecting gradio
  Downloading gradio-5.43.1-py3-none-any.whl.metadata (16 kB)
Requirement already satisfied: gradio-client in /usr/local/lib/python3.12/dist-packages (1.7.2)
Collecting gradio-client
  Downloading gradio_client-1.12.1-py3-none-any.whl.metadata (7.1 kB)
Requirement already satisfied: aiofiles<25.0,>=22.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (23.2.1)
Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (4.10.0)
Requirement already satisfied: brotli>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (1.1.0)
Requirement already satisfied: fastapi<1.0,>=0.115.2 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.116.1)
Requirement already satisfied: ffmpy in /usr/local/lib/python3.12/dist-packages (from gradio) (0.6.1)
Requirement already satisfied: groovy~=0.1 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.1.2)
Requirement already satisfied: httpx<1.0,>=0.24.1 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.28.1)
Requirement already satisfied: huggingface-hub<1.0,>=0.33.5 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.34.4)
Requirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (3.1.6)
Requirement already satisfied: markupsafe<4.0,>=2.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (2.1.5)
Requirement already satisfied: numpy<3.0,>=1.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (2.0.2)
Requirement already satisfied: orjson~=3.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (3.11.2)
Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from gradio) (25.0)
Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (2.2.2)
Requirement already satisfied: pillow<12.0,>=8.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (11.3.0)
Requirement already satisfied: pydantic<2.12,>=2.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (2.11.7)
Requirement already satisfied: pydub in /usr/local/lib/python3.12/dist-packages (from gradio) (0.25.1)
Requirement already satisfied: python-multipart>=0.0.18 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.0.20)
Requirement already satisfied: pyyaml<7.0,>=5.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (6.0.2)
Requirement already satisfied: ruff>=0.9.3 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.12.9)
Requirement already satisfied: safehttpx<0.2.0,>=0.1.6 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.1.6)
Requirement already satisfied: semantic-version~=2.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (2.10.0)
Requirement already satisfied: starlette<1.0,>=0.40.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.47.2)
Requirement already satisfied: tomlkit<0.14.0,>=0.12.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.13.3)
Requirement already satisfied: typer<1.0,>=0.12 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.16.0)
Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (4.14.1)
Requirement already satisfied: uvicorn>=0.14.0 in /usr/local/lib/python3.12/dist-packages (from gradio) (0.35.0)
Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from gradio-client) (2025.3.0)
Requirement already satisfied: websockets<16.0,>=10.0 in /usr/local/lib/python3.12/dist-packages (from gradio-client) (15.0.1)
Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.12/dist-packages (from anyio<5.0,>=3.0->gradio) (3.10)
Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.12/dist-packages (from anyio<5.0,>=3.0->gradio) (1.3.1)
Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from httpx<1.0,>=0.24.1->gradio) (2025.8.3)
Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1.0,>=0.24.1->gradio) (1.0.9)
Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1.0,>=0.24.1->gradio) (0.16.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.33.5->gradio) (3.19.1)
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.33.5->gradio) (2.32.4)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.33.5->gradio) (4.67.1)
Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.33.5->gradio) (1.1.7)
Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas<3.0,>=1.0->gradio) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas<3.0,>=1.0->gradio) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas<3.0,>=1.0->gradio) (2025.2)
Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic<2.12,>=2.0->gradio) (0.7.0)
Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic<2.12,>=2.0->gradio) (2.33.2)
Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic<2.12,>=2.0->gradio) (0.4.1)
Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.12/dist-packages (from typer<1.0,>=0.12->gradio) (8.2.1)
Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from typer<1.0,>=0.12->gradio) (1.5.4)
Requirement already satisfied: rich>=10.11.0 in /usr/local/lib/python3.12/dist-packages (from typer<1.0,>=0.12->gradio) (13.9.4)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas<3.0,>=1.0->gradio) (1.17.0)
Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (4.0.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (2.19.2)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub<1.0,>=0.33.5->gradio) (3.4.3)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub<1.0,>=0.33.5->gradio) (2.5.0)
Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio) (0.1.2)
Downloading gradio-5.43.1-py3-none-any.whl (59.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 59.6/59.6 MB 13.4 MB/s eta 0:00:00
Downloading gradio_client-1.12.1-py3-none-any.whl (324 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 324.6/324.6 kB 30.5 MB/s eta 0:00:00
Installing collected packages: gradio-client, gradio
  Attempting uninstall: gradio-client
    Found existing installation: gradio_client 1.7.2
    Uninstalling gradio_client-1.7.2:
      Successfully uninstalled gradio_client-1.7.2
  Attempting uninstall: gradio
    Found existing installation: gradio 5.18.0
    Uninstalling gradio-5.18.0:
      Successfully uninstalled gradio-5.18.0
Successfully installed gradio-5.43.1 gradio-client-1.12.1

Finally, app.py is run. It provides an image-recognition interface served from Colab, which can identify the species of this dataset in any uploaded image.

In [21]:
%cd /content/
!python app.py
/content
theme_schema%400.0.4.json: 12.6kB [00:00, 49.5MB/s]
* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://d20b1e69686cd0b3e0.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
Keyboard interruption in main thread... closing server.
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/gradio/blocks.py", line 3158, in block_thread
    time.sleep(0.1)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/content/app.py", line 237, in <module>
    demo.queue().launch(share=True, inbrowser=True)
  File "/usr/local/lib/python3.12/dist-packages/gradio/blocks.py", line 3055, in launch
    self.block_thread()
  File "/usr/local/lib/python3.12/dist-packages/gradio/blocks.py", line 3162, in block_thread
    self.server.close()
  File "/usr/local/lib/python3.12/dist-packages/gradio/http_server.py", line 69, in close
    self.thread.join(timeout=5)
  File "/usr/lib/python3.12/threading.py", line 1153, in join
    self._wait_for_tstate_lock(timeout=max(timeout, 0))
  File "/usr/lib/python3.12/threading.py", line 1169, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
Killing tunnel 127.0.0.1:7860 <> https://d20b1e69686cd0b3e0.gradio.live

I adjusted the code so the folders are only copied, so the previous model is still available.

Now I want to compare the models with a plot, so I try to obtain the precision-recall curve.
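Before that, a minimal sketch of how a per-class precision-recall curve could be computed for this kind of torchvision detection model, once the checkpoint is loaded as in the cells below. This is not the repository's own evaluation code: val_loader (yielding (images, targets) in torchvision detection format), the IoU threshold of 0.5, and the greedy matching are illustrative assumptions.

# Sketch: per-class precision-recall points for a torchvision detection model.
import numpy as np
import torch
from torchvision.ops import box_iou

@torch.no_grad()
def pr_curve(model, val_loader, class_id, device, iou_thr=0.5):
    scores, matches, n_gt = [], [], 0
    model.eval()
    for images, targets in val_loader:
        outputs = model([img.to(device) for img in images])
        for out, tgt in zip(outputs, targets):
            keep = out["labels"] == class_id
            boxes, confs = out["boxes"][keep].cpu(), out["scores"][keep].cpu()
            gt = tgt["boxes"][tgt["labels"] == class_id]
            n_gt += len(gt)
            used = torch.zeros(len(gt), dtype=torch.bool)
            # Greedily match each detection (highest score first) to an unused GT box.
            for b, s in sorted(zip(boxes, confs), key=lambda x: -float(x[1])):
                scores.append(float(s))
                if len(gt):
                    ious = box_iou(b.unsqueeze(0), gt)[0]
                    best = int(ious.argmax())
                    if ious[best] >= iou_thr and not used[best]:
                        used[best] = True
                        matches.append(1)
                        continue
                matches.append(0)
    order = np.argsort(-np.asarray(scores))
    tp = np.cumsum(np.asarray(matches, dtype=float)[order])
    fp = np.cumsum(1.0 - np.asarray(matches, dtype=float)[order])
    precision = tp / np.maximum(tp + fp, 1)
    recall = tp / max(n_gt, 1)
    return precision, recall

The (recall, precision) arrays obtained for each model can then be drawn on the same axes (e.g. plt.plot(recall, precision)) to compare them.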

First, inspect the contents of best_model_79.pth.

In [24]:
import torch

# Load the checkpoint file
checkpoint = torch.load("best_model_79.pth", map_location="cpu")

# Inspect its type and keys
print(type(checkpoint))
print(checkpoint.keys() if isinstance(checkpoint, dict) else "Not a dictionary")
<class 'dict'>
dict_keys(['epoch', 'model_state_dict'])
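Besides the weights, the checkpoint records the epoch at which it was saved, which can be read directly:

# The 'epoch' entry records when this checkpoint was saved.
print("Saved at epoch:", checkpoint["epoch"])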

List the model's weight tensors and their shapes.

In [25]:
state_dict = checkpoint if not isinstance(checkpoint, dict) else checkpoint['model_state_dict']

for k, v in state_dict.items():
    print(k, v.shape)
backbone.body.conv1.weight torch.Size([64, 3, 7, 7])
backbone.body.bn1.weight torch.Size([64])
backbone.body.bn1.bias torch.Size([64])
backbone.body.bn1.running_mean torch.Size([64])
backbone.body.bn1.running_var torch.Size([64])
backbone.body.bn1.num_batches_tracked torch.Size([])
backbone.body.layer1.0.conv1.weight torch.Size([64, 64, 1, 1])
backbone.body.layer1.0.bn1.weight torch.Size([64])
backbone.body.layer1.0.bn1.bias torch.Size([64])
backbone.body.layer1.0.bn1.running_mean torch.Size([64])
backbone.body.layer1.0.bn1.running_var torch.Size([64])
backbone.body.layer1.0.bn1.num_batches_tracked torch.Size([])
backbone.body.layer1.0.conv2.weight torch.Size([64, 64, 3, 3])
backbone.body.layer1.0.bn2.weight torch.Size([64])
backbone.body.layer1.0.bn2.bias torch.Size([64])
backbone.body.layer1.0.bn2.running_mean torch.Size([64])
backbone.body.layer1.0.bn2.running_var torch.Size([64])
backbone.body.layer1.0.bn2.num_batches_tracked torch.Size([])
backbone.body.layer1.0.conv3.weight torch.Size([256, 64, 1, 1])
backbone.body.layer1.0.bn3.weight torch.Size([256])
backbone.body.layer1.0.bn3.bias torch.Size([256])
backbone.body.layer1.0.bn3.running_mean torch.Size([256])
backbone.body.layer1.0.bn3.running_var torch.Size([256])
backbone.body.layer1.0.bn3.num_batches_tracked torch.Size([])
backbone.body.layer1.0.downsample.0.weight torch.Size([256, 64, 1, 1])
backbone.body.layer1.0.downsample.1.weight torch.Size([256])
backbone.body.layer1.0.downsample.1.bias torch.Size([256])
backbone.body.layer1.0.downsample.1.running_mean torch.Size([256])
backbone.body.layer1.0.downsample.1.running_var torch.Size([256])
backbone.body.layer1.0.downsample.1.num_batches_tracked torch.Size([])
backbone.body.layer1.1.conv1.weight torch.Size([64, 256, 1, 1])
backbone.body.layer1.1.bn1.weight torch.Size([64])
backbone.body.layer1.1.bn1.bias torch.Size([64])
backbone.body.layer1.1.bn1.running_mean torch.Size([64])
backbone.body.layer1.1.bn1.running_var torch.Size([64])
backbone.body.layer1.1.bn1.num_batches_tracked torch.Size([])
backbone.body.layer1.1.conv2.weight torch.Size([64, 64, 3, 3])
backbone.body.layer1.1.bn2.weight torch.Size([64])
backbone.body.layer1.1.bn2.bias torch.Size([64])
backbone.body.layer1.1.bn2.running_mean torch.Size([64])
backbone.body.layer1.1.bn2.running_var torch.Size([64])
backbone.body.layer1.1.bn2.num_batches_tracked torch.Size([])
backbone.body.layer1.1.conv3.weight torch.Size([256, 64, 1, 1])
backbone.body.layer1.1.bn3.weight torch.Size([256])
backbone.body.layer1.1.bn3.bias torch.Size([256])
backbone.body.layer1.1.bn3.running_mean torch.Size([256])
backbone.body.layer1.1.bn3.running_var torch.Size([256])
backbone.body.layer1.1.bn3.num_batches_tracked torch.Size([])
backbone.body.layer1.2.conv1.weight torch.Size([64, 256, 1, 1])
backbone.body.layer1.2.bn1.weight torch.Size([64])
backbone.body.layer1.2.bn1.bias torch.Size([64])
backbone.body.layer1.2.bn1.running_mean torch.Size([64])
backbone.body.layer1.2.bn1.running_var torch.Size([64])
backbone.body.layer1.2.bn1.num_batches_tracked torch.Size([])
backbone.body.layer1.2.conv2.weight torch.Size([64, 64, 3, 3])
backbone.body.layer1.2.bn2.weight torch.Size([64])
backbone.body.layer1.2.bn2.bias torch.Size([64])
backbone.body.layer1.2.bn2.running_mean torch.Size([64])
backbone.body.layer1.2.bn2.running_var torch.Size([64])
backbone.body.layer1.2.bn2.num_batches_tracked torch.Size([])
backbone.body.layer1.2.conv3.weight torch.Size([256, 64, 1, 1])
backbone.body.layer1.2.bn3.weight torch.Size([256])
backbone.body.layer1.2.bn3.bias torch.Size([256])
backbone.body.layer1.2.bn3.running_mean torch.Size([256])
backbone.body.layer1.2.bn3.running_var torch.Size([256])
backbone.body.layer1.2.bn3.num_batches_tracked torch.Size([])
backbone.body.layer2.0.conv1.weight torch.Size([128, 256, 1, 1])
backbone.body.layer2.0.bn1.weight torch.Size([128])
backbone.body.layer2.0.bn1.bias torch.Size([128])
backbone.body.layer2.0.bn1.running_mean torch.Size([128])
backbone.body.layer2.0.bn1.running_var torch.Size([128])
backbone.body.layer2.0.bn1.num_batches_tracked torch.Size([])
backbone.body.layer2.0.conv2.weight torch.Size([128, 128, 3, 3])
backbone.body.layer2.0.bn2.weight torch.Size([128])
backbone.body.layer2.0.bn2.bias torch.Size([128])
backbone.body.layer2.0.bn2.running_mean torch.Size([128])
backbone.body.layer2.0.bn2.running_var torch.Size([128])
backbone.body.layer2.0.bn2.num_batches_tracked torch.Size([])
backbone.body.layer2.0.conv3.weight torch.Size([512, 128, 1, 1])
backbone.body.layer2.0.bn3.weight torch.Size([512])
backbone.body.layer2.0.bn3.bias torch.Size([512])
backbone.body.layer2.0.bn3.running_mean torch.Size([512])
backbone.body.layer2.0.bn3.running_var torch.Size([512])
backbone.body.layer2.0.bn3.num_batches_tracked torch.Size([])
backbone.body.layer2.0.downsample.0.weight torch.Size([512, 256, 1, 1])
backbone.body.layer2.0.downsample.1.weight torch.Size([512])
backbone.body.layer2.0.downsample.1.bias torch.Size([512])
backbone.body.layer2.0.downsample.1.running_mean torch.Size([512])
backbone.body.layer2.0.downsample.1.running_var torch.Size([512])
backbone.body.layer2.0.downsample.1.num_batches_tracked torch.Size([])
backbone.body.layer2.1.conv1.weight torch.Size([128, 512, 1, 1])
backbone.body.layer2.1.bn1.weight torch.Size([128])
backbone.body.layer2.1.bn1.bias torch.Size([128])
backbone.body.layer2.1.bn1.running_mean torch.Size([128])
backbone.body.layer2.1.bn1.running_var torch.Size([128])
backbone.body.layer2.1.bn1.num_batches_tracked torch.Size([])
backbone.body.layer2.1.conv2.weight torch.Size([128, 128, 3, 3])
backbone.body.layer2.1.bn2.weight torch.Size([128])
backbone.body.layer2.1.bn2.bias torch.Size([128])
backbone.body.layer2.1.bn2.running_mean torch.Size([128])
backbone.body.layer2.1.bn2.running_var torch.Size([128])
backbone.body.layer2.1.bn2.num_batches_tracked torch.Size([])
backbone.body.layer2.1.conv3.weight torch.Size([512, 128, 1, 1])
backbone.body.layer2.1.bn3.weight torch.Size([512])
backbone.body.layer2.1.bn3.bias torch.Size([512])
backbone.body.layer2.1.bn3.running_mean torch.Size([512])
backbone.body.layer2.1.bn3.running_var torch.Size([512])
backbone.body.layer2.1.bn3.num_batches_tracked torch.Size([])
backbone.body.layer2.2.conv1.weight torch.Size([128, 512, 1, 1])
backbone.body.layer2.2.bn1.weight torch.Size([128])
backbone.body.layer2.2.bn1.bias torch.Size([128])
backbone.body.layer2.2.bn1.running_mean torch.Size([128])
backbone.body.layer2.2.bn1.running_var torch.Size([128])
backbone.body.layer2.2.bn1.num_batches_tracked torch.Size([])
backbone.body.layer2.2.conv2.weight torch.Size([128, 128, 3, 3])
backbone.body.layer2.2.bn2.weight torch.Size([128])
backbone.body.layer2.2.bn2.bias torch.Size([128])
backbone.body.layer2.2.bn2.running_mean torch.Size([128])
backbone.body.layer2.2.bn2.running_var torch.Size([128])
backbone.body.layer2.2.bn2.num_batches_tracked torch.Size([])
backbone.body.layer2.2.conv3.weight torch.Size([512, 128, 1, 1])
backbone.body.layer2.2.bn3.weight torch.Size([512])
backbone.body.layer2.2.bn3.bias torch.Size([512])
backbone.body.layer2.2.bn3.running_mean torch.Size([512])
backbone.body.layer2.2.bn3.running_var torch.Size([512])
backbone.body.layer2.2.bn3.num_batches_tracked torch.Size([])
backbone.body.layer2.3.conv1.weight torch.Size([128, 512, 1, 1])
backbone.body.layer2.3.bn1.weight torch.Size([128])
backbone.body.layer2.3.bn1.bias torch.Size([128])
backbone.body.layer2.3.bn1.running_mean torch.Size([128])
backbone.body.layer2.3.bn1.running_var torch.Size([128])
backbone.body.layer2.3.bn1.num_batches_tracked torch.Size([])
backbone.body.layer2.3.conv2.weight torch.Size([128, 128, 3, 3])
backbone.body.layer2.3.bn2.weight torch.Size([128])
backbone.body.layer2.3.bn2.bias torch.Size([128])
backbone.body.layer2.3.bn2.running_mean torch.Size([128])
backbone.body.layer2.3.bn2.running_var torch.Size([128])
backbone.body.layer2.3.bn2.num_batches_tracked torch.Size([])
backbone.body.layer2.3.conv3.weight torch.Size([512, 128, 1, 1])
backbone.body.layer2.3.bn3.weight torch.Size([512])
backbone.body.layer2.3.bn3.bias torch.Size([512])
backbone.body.layer2.3.bn3.running_mean torch.Size([512])
backbone.body.layer2.3.bn3.running_var torch.Size([512])
backbone.body.layer2.3.bn3.num_batches_tracked torch.Size([])
backbone.body.layer3.0.conv1.weight torch.Size([256, 512, 1, 1])
backbone.body.layer3.0.bn1.weight torch.Size([256])
backbone.body.layer3.0.bn1.bias torch.Size([256])
backbone.body.layer3.0.bn1.running_mean torch.Size([256])
backbone.body.layer3.0.bn1.running_var torch.Size([256])
backbone.body.layer3.0.bn1.num_batches_tracked torch.Size([])
backbone.body.layer3.0.conv2.weight torch.Size([256, 256, 3, 3])
backbone.body.layer3.0.bn2.weight torch.Size([256])
backbone.body.layer3.0.bn2.bias torch.Size([256])
backbone.body.layer3.0.bn2.running_mean torch.Size([256])
backbone.body.layer3.0.bn2.running_var torch.Size([256])
backbone.body.layer3.0.bn2.num_batches_tracked torch.Size([])
backbone.body.layer3.0.conv3.weight torch.Size([1024, 256, 1, 1])
backbone.body.layer3.0.bn3.weight torch.Size([1024])
backbone.body.layer3.0.bn3.bias torch.Size([1024])
backbone.body.layer3.0.bn3.running_mean torch.Size([1024])
backbone.body.layer3.0.bn3.running_var torch.Size([1024])
backbone.body.layer3.0.bn3.num_batches_tracked torch.Size([])
backbone.body.layer3.0.downsample.0.weight torch.Size([1024, 512, 1, 1])
backbone.body.layer3.0.downsample.1.weight torch.Size([1024])
backbone.body.layer3.0.downsample.1.bias torch.Size([1024])
backbone.body.layer3.0.downsample.1.running_mean torch.Size([1024])
backbone.body.layer3.0.downsample.1.running_var torch.Size([1024])
backbone.body.layer3.0.downsample.1.num_batches_tracked torch.Size([])
backbone.body.layer3.1.conv1.weight torch.Size([256, 1024, 1, 1])
backbone.body.layer3.1.bn1.weight torch.Size([256])
backbone.body.layer3.1.bn1.bias torch.Size([256])
backbone.body.layer3.1.bn1.running_mean torch.Size([256])
backbone.body.layer3.1.bn1.running_var torch.Size([256])
backbone.body.layer3.1.bn1.num_batches_tracked torch.Size([])
backbone.body.layer3.1.conv2.weight torch.Size([256, 256, 3, 3])
backbone.body.layer3.1.bn2.weight torch.Size([256])
backbone.body.layer3.1.bn2.bias torch.Size([256])
backbone.body.layer3.1.bn2.running_mean torch.Size([256])
backbone.body.layer3.1.bn2.running_var torch.Size([256])
backbone.body.layer3.1.bn2.num_batches_tracked torch.Size([])
backbone.body.layer3.1.conv3.weight torch.Size([1024, 256, 1, 1])
backbone.body.layer3.1.bn3.weight torch.Size([1024])
backbone.body.layer3.1.bn3.bias torch.Size([1024])
backbone.body.layer3.1.bn3.running_mean torch.Size([1024])
backbone.body.layer3.1.bn3.running_var torch.Size([1024])
backbone.body.layer3.1.bn3.num_batches_tracked torch.Size([])
backbone.body.layer3.2.conv1.weight torch.Size([256, 1024, 1, 1])
backbone.body.layer3.2.bn1.weight torch.Size([256])
backbone.body.layer3.2.bn1.bias torch.Size([256])
backbone.body.layer3.2.bn1.running_mean torch.Size([256])
backbone.body.layer3.2.bn1.running_var torch.Size([256])
backbone.body.layer3.2.bn1.num_batches_tracked torch.Size([])
backbone.body.layer3.2.conv2.weight torch.Size([256, 256, 3, 3])
backbone.body.layer3.2.bn2.weight torch.Size([256])
backbone.body.layer3.2.bn2.bias torch.Size([256])
backbone.body.layer3.2.bn2.running_mean torch.Size([256])
backbone.body.layer3.2.bn2.running_var torch.Size([256])
backbone.body.layer3.2.bn2.num_batches_tracked torch.Size([])
backbone.body.layer3.2.conv3.weight torch.Size([1024, 256, 1, 1])
backbone.body.layer3.2.bn3.weight torch.Size([1024])
backbone.body.layer3.2.bn3.bias torch.Size([1024])
backbone.body.layer3.2.bn3.running_mean torch.Size([1024])
backbone.body.layer3.2.bn3.running_var torch.Size([1024])
backbone.body.layer3.2.bn3.num_batches_tracked torch.Size([])
backbone.body.layer3.3.conv1.weight torch.Size([256, 1024, 1, 1])
backbone.body.layer3.3.bn1.weight torch.Size([256])
backbone.body.layer3.3.bn1.bias torch.Size([256])
backbone.body.layer3.3.bn1.running_mean torch.Size([256])
backbone.body.layer3.3.bn1.running_var torch.Size([256])
backbone.body.layer3.3.bn1.num_batches_tracked torch.Size([])
backbone.body.layer3.3.conv2.weight torch.Size([256, 256, 3, 3])
backbone.body.layer3.3.bn2.weight torch.Size([256])
backbone.body.layer3.3.bn2.bias torch.Size([256])
backbone.body.layer3.3.bn2.running_mean torch.Size([256])
backbone.body.layer3.3.bn2.running_var torch.Size([256])
backbone.body.layer3.3.bn2.num_batches_tracked torch.Size([])
backbone.body.layer3.3.conv3.weight torch.Size([1024, 256, 1, 1])
backbone.body.layer3.3.bn3.weight torch.Size([1024])
backbone.body.layer3.3.bn3.bias torch.Size([1024])
backbone.body.layer3.3.bn3.running_mean torch.Size([1024])
backbone.body.layer3.3.bn3.running_var torch.Size([1024])
backbone.body.layer3.3.bn3.num_batches_tracked torch.Size([])
backbone.body.layer3.4.conv1.weight torch.Size([256, 1024, 1, 1])
backbone.body.layer3.4.bn1.weight torch.Size([256])
backbone.body.layer3.4.bn1.bias torch.Size([256])
backbone.body.layer3.4.bn1.running_mean torch.Size([256])
backbone.body.layer3.4.bn1.running_var torch.Size([256])
backbone.body.layer3.4.bn1.num_batches_tracked torch.Size([])
backbone.body.layer3.4.conv2.weight torch.Size([256, 256, 3, 3])
backbone.body.layer3.4.bn2.weight torch.Size([256])
backbone.body.layer3.4.bn2.bias torch.Size([256])
backbone.body.layer3.4.bn2.running_mean torch.Size([256])
backbone.body.layer3.4.bn2.running_var torch.Size([256])
backbone.body.layer3.4.bn2.num_batches_tracked torch.Size([])
backbone.body.layer3.4.conv3.weight torch.Size([1024, 256, 1, 1])
backbone.body.layer3.4.bn3.weight torch.Size([1024])
backbone.body.layer3.4.bn3.bias torch.Size([1024])
backbone.body.layer3.4.bn3.running_mean torch.Size([1024])
backbone.body.layer3.4.bn3.running_var torch.Size([1024])
backbone.body.layer3.4.bn3.num_batches_tracked torch.Size([])
backbone.body.layer3.5.conv1.weight torch.Size([256, 1024, 1, 1])
backbone.body.layer3.5.bn1.weight torch.Size([256])
backbone.body.layer3.5.bn1.bias torch.Size([256])
backbone.body.layer3.5.bn1.running_mean torch.Size([256])
backbone.body.layer3.5.bn1.running_var torch.Size([256])
backbone.body.layer3.5.bn1.num_batches_tracked torch.Size([])
backbone.body.layer3.5.conv2.weight torch.Size([256, 256, 3, 3])
backbone.body.layer3.5.bn2.weight torch.Size([256])
backbone.body.layer3.5.bn2.bias torch.Size([256])
backbone.body.layer3.5.bn2.running_mean torch.Size([256])
backbone.body.layer3.5.bn2.running_var torch.Size([256])
backbone.body.layer3.5.bn2.num_batches_tracked torch.Size([])
backbone.body.layer3.5.conv3.weight torch.Size([1024, 256, 1, 1])
backbone.body.layer3.5.bn3.weight torch.Size([1024])
backbone.body.layer3.5.bn3.bias torch.Size([1024])
backbone.body.layer3.5.bn3.running_mean torch.Size([1024])
backbone.body.layer3.5.bn3.running_var torch.Size([1024])
backbone.body.layer3.5.bn3.num_batches_tracked torch.Size([])
backbone.body.layer4.0.conv1.weight torch.Size([512, 1024, 1, 1])
backbone.body.layer4.0.bn1.weight torch.Size([512])
backbone.body.layer4.0.bn1.bias torch.Size([512])
backbone.body.layer4.0.bn1.running_mean torch.Size([512])
backbone.body.layer4.0.bn1.running_var torch.Size([512])
backbone.body.layer4.0.bn1.num_batches_tracked torch.Size([])
backbone.body.layer4.0.conv2.weight torch.Size([512, 512, 3, 3])
backbone.body.layer4.0.bn2.weight torch.Size([512])
backbone.body.layer4.0.bn2.bias torch.Size([512])
backbone.body.layer4.0.bn2.running_mean torch.Size([512])
backbone.body.layer4.0.bn2.running_var torch.Size([512])
backbone.body.layer4.0.bn2.num_batches_tracked torch.Size([])
backbone.body.layer4.0.conv3.weight torch.Size([2048, 512, 1, 1])
backbone.body.layer4.0.bn3.weight torch.Size([2048])
backbone.body.layer4.0.bn3.bias torch.Size([2048])
backbone.body.layer4.0.bn3.running_mean torch.Size([2048])
backbone.body.layer4.0.bn3.running_var torch.Size([2048])
backbone.body.layer4.0.bn3.num_batches_tracked torch.Size([])
backbone.body.layer4.0.downsample.0.weight torch.Size([2048, 1024, 1, 1])
backbone.body.layer4.0.downsample.1.weight torch.Size([2048])
backbone.body.layer4.0.downsample.1.bias torch.Size([2048])
backbone.body.layer4.0.downsample.1.running_mean torch.Size([2048])
backbone.body.layer4.0.downsample.1.running_var torch.Size([2048])
backbone.body.layer4.0.downsample.1.num_batches_tracked torch.Size([])
backbone.body.layer4.1.conv1.weight torch.Size([512, 2048, 1, 1])
backbone.body.layer4.1.bn1.weight torch.Size([512])
backbone.body.layer4.1.bn1.bias torch.Size([512])
backbone.body.layer4.1.bn1.running_mean torch.Size([512])
backbone.body.layer4.1.bn1.running_var torch.Size([512])
backbone.body.layer4.1.bn1.num_batches_tracked torch.Size([])
backbone.body.layer4.1.conv2.weight torch.Size([512, 512, 3, 3])
backbone.body.layer4.1.bn2.weight torch.Size([512])
backbone.body.layer4.1.bn2.bias torch.Size([512])
backbone.body.layer4.1.bn2.running_mean torch.Size([512])
backbone.body.layer4.1.bn2.running_var torch.Size([512])
backbone.body.layer4.1.bn2.num_batches_tracked torch.Size([])
backbone.body.layer4.1.conv3.weight torch.Size([2048, 512, 1, 1])
backbone.body.layer4.1.bn3.weight torch.Size([2048])
backbone.body.layer4.1.bn3.bias torch.Size([2048])
backbone.body.layer4.1.bn3.running_mean torch.Size([2048])
backbone.body.layer4.1.bn3.running_var torch.Size([2048])
backbone.body.layer4.1.bn3.num_batches_tracked torch.Size([])
backbone.body.layer4.2.conv1.weight torch.Size([512, 2048, 1, 1])
backbone.body.layer4.2.bn1.weight torch.Size([512])
backbone.body.layer4.2.bn1.bias torch.Size([512])
backbone.body.layer4.2.bn1.running_mean torch.Size([512])
backbone.body.layer4.2.bn1.running_var torch.Size([512])
backbone.body.layer4.2.bn1.num_batches_tracked torch.Size([])
backbone.body.layer4.2.conv2.weight torch.Size([512, 512, 3, 3])
backbone.body.layer4.2.bn2.weight torch.Size([512])
backbone.body.layer4.2.bn2.bias torch.Size([512])
backbone.body.layer4.2.bn2.running_mean torch.Size([512])
backbone.body.layer4.2.bn2.running_var torch.Size([512])
backbone.body.layer4.2.bn2.num_batches_tracked torch.Size([])
backbone.body.layer4.2.conv3.weight torch.Size([2048, 512, 1, 1])
backbone.body.layer4.2.bn3.weight torch.Size([2048])
backbone.body.layer4.2.bn3.bias torch.Size([2048])
backbone.body.layer4.2.bn3.running_mean torch.Size([2048])
backbone.body.layer4.2.bn3.running_var torch.Size([2048])
backbone.body.layer4.2.bn3.num_batches_tracked torch.Size([])
backbone.fpn.inner_blocks.0.0.weight torch.Size([256, 512, 1, 1])
backbone.fpn.inner_blocks.0.0.bias torch.Size([256])
backbone.fpn.inner_blocks.1.0.weight torch.Size([256, 1024, 1, 1])
backbone.fpn.inner_blocks.1.0.bias torch.Size([256])
backbone.fpn.inner_blocks.2.0.weight torch.Size([256, 2048, 1, 1])
backbone.fpn.inner_blocks.2.0.bias torch.Size([256])
backbone.fpn.layer_blocks.0.0.weight torch.Size([256, 256, 3, 3])
backbone.fpn.layer_blocks.0.0.bias torch.Size([256])
backbone.fpn.layer_blocks.1.0.weight torch.Size([256, 256, 3, 3])
backbone.fpn.layer_blocks.1.0.bias torch.Size([256])
backbone.fpn.layer_blocks.2.0.weight torch.Size([256, 256, 3, 3])
backbone.fpn.layer_blocks.2.0.bias torch.Size([256])
backbone.fpn.extra_blocks.p6.weight torch.Size([256, 2048, 3, 3])
backbone.fpn.extra_blocks.p6.bias torch.Size([256])
backbone.fpn.extra_blocks.p7.weight torch.Size([256, 256, 3, 3])
backbone.fpn.extra_blocks.p7.bias torch.Size([256])
head.classification_head.conv.0.0.weight torch.Size([256, 256, 3, 3])
head.classification_head.conv.0.1.weight torch.Size([256])
head.classification_head.conv.0.1.bias torch.Size([256])
head.classification_head.conv.1.0.weight torch.Size([256, 256, 3, 3])
head.classification_head.conv.1.1.weight torch.Size([256])
head.classification_head.conv.1.1.bias torch.Size([256])
head.classification_head.conv.2.0.weight torch.Size([256, 256, 3, 3])
head.classification_head.conv.2.1.weight torch.Size([256])
head.classification_head.conv.2.1.bias torch.Size([256])
head.classification_head.conv.3.0.weight torch.Size([256, 256, 3, 3])
head.classification_head.conv.3.1.weight torch.Size([256])
head.classification_head.conv.3.1.bias torch.Size([256])
head.classification_head.cls_logits.weight torch.Size([45, 256, 3, 3])
head.classification_head.cls_logits.bias torch.Size([45])
head.regression_head.conv.0.0.weight torch.Size([256, 256, 3, 3])
head.regression_head.conv.0.1.weight torch.Size([256])
head.regression_head.conv.0.1.bias torch.Size([256])
head.regression_head.conv.1.0.weight torch.Size([256, 256, 3, 3])
head.regression_head.conv.1.1.weight torch.Size([256])
head.regression_head.conv.1.1.bias torch.Size([256])
head.regression_head.conv.2.0.weight torch.Size([256, 256, 3, 3])
head.regression_head.conv.2.1.weight torch.Size([256])
head.regression_head.conv.2.1.bias torch.Size([256])
head.regression_head.conv.3.0.weight torch.Size([256, 256, 3, 3])
head.regression_head.conv.3.1.weight torch.Size([256])
head.regression_head.conv.3.1.bias torch.Size([256])
head.regression_head.bbox_reg.weight torch.Size([36, 256, 3, 3])
head.regression_head.bbox_reg.bias torch.Size([36])
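A small sketch to condense the listing above: aggregating the tensor sizes per top-level module of the state_dict. Note that these sums also include BatchNorm buffers (running_mean, running_var, num_batches_tracked), so they are slightly larger than the pure parameter counts.

from collections import defaultdict

totals = defaultdict(int)
for key, tensor in state_dict.items():
    totals[key.split(".")[0]] += tensor.numel()

for module, n in sorted(totals.items()):
    print(f"{module}: {n:,} values")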

Load the model so it can be used.

In [26]:
from model import create_model  # model creation helper
from config import NUM_CLASSES, DEVICE

model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("best_model_79.pth", map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE).eval()
Out[26]:
RetinaNet(
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2dNormActivation(
          (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (layer_blocks): ModuleList(
        (0-2): 3 x Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (extra_blocks): LastLevelP6P7(
        (p6): Conv2d(2048, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (p7): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
    )
  )
  (anchor_generator): AnchorGenerator()
  (head): RetinaNetHead(
    (classification_head): RetinaNetClassificationHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
      )
      (cls_logits): Conv2d(256, 45, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (regression_head): RetinaNetRegressionHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
        )
      )
      (bbox_reg): Conv2d(256, 36, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
)
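
A quick way to read the head shapes above: torchvision's RetinaNet predicts a fixed number of anchors per feature-map location, and the output channels of cls_logits and bbox_reg are num_anchors * num_classes and num_anchors * 4 respectively. A minimal sanity check, assuming the default anchor generator (3 scales × 3 aspect ratios = 9 anchors) and a model built with num_classes=5 (the 4 wildlife classes plus background):

# Sanity check of the printed head dimensions (assumptions noted above)
num_anchors = 9   # torchvision default: 3 scales x 3 aspect ratios
num_classes = 5   # buffalo, elephant, rhino, zebra + background

assert num_anchors * num_classes == 45  # cls_logits: Conv2d(256, 45, ...)
assert num_anchors * 4 == 36            # bbox_reg:  Conv2d(256, 36, ...), 4 box offsets per anchor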

Use the custom dataset code (datasets.py) to build the validation loader.

In [27]:
from torch.utils.data import DataLoader
from datasets import create_valid_dataset
from config import BATCH_SIZE, DEVICE

valid_dataset = create_valid_dataset("/content/datasets/african-wildlife/val")
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
/usr/local/lib/python3.12/dist-packages/albumentations/core/composition.py:331: UserWarning: Got processor for bboxes, but no transform to process it.
  self._set_keys()
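
The albumentations warning is harmless here: it only says that the Compose pipeline received bbox_params but none of the validation transforms actually move boxes around (the validation pipeline presumably just converts to tensors). For reference, this is roughly how such a pipeline is declared; the exact transforms live in the downloaded datasets.py, so treat this as an illustrative sketch:

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Illustrative validation transform: no geometric augmentation, so the
# bbox processor has nothing to do and albumentations warns about it.
valid_transform = A.Compose(
    [ToTensorV2()],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)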

After the error, inspect the labels inside targets.

In [30]:
print(targets)
{'boxes': tensor([[[358.4000, 424.3309, 395.2000, 425.3309]]]), 'labels': tensor([[1]]), 'area': tensor([[36.8000]]), 'iscrowd': tensor([[0]]), 'image_id': tensor([[0]])}
In [31]:
true_labels = targets['labels'].squeeze(0).cpu().numpy()
In [32]:
true_boxes = targets['boxes'].squeeze(0).cpu().numpy()

Fix the KeyError.

In [33]:
all_scores = []
all_labels = []

for images, targets in valid_loader:
    images = images.to(DEVICE)
    with torch.no_grad():
        outputs = model(images)

    pred_scores = outputs[0]['scores'].cpu().numpy()
    pred_labels = outputs[0]['labels'].cpu().numpy()

    true_labels = targets['labels'].squeeze(0).cpu().numpy()

    all_scores.extend(pred_scores)
    all_labels.extend(true_labels)

Separate the classes (species).

In [34]:
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.preprocessing import label_binarize  # was missing: label_binarize is used below
import matplotlib.pyplot as plt
import numpy as np

# List of classes
classes = ["buffalo", "elephant", "rhino", "zebra"]
n_classes = len(classes)

# Binarize the ground-truth labels (assumes labels 1..4; adjust to your indexing)
y_true_bin = label_binarize(all_labels, classes=range(1, n_classes+1))

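For reference, label_binarize produces one 0/1 column per class; a tiny example with made-up label values:

from sklearn.preprocessing import label_binarize

# Hypothetical labels 1..4, just to show the output layout
print(label_binarize([1, 3, 2, 4], classes=range(1, 5)))
# [[1 0 0 0]
#  [0 0 1 0]
#  [0 1 0 0]
#  [0 0 0 1]]
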
Compute the Precision-Recall curve and AP per class.

In [35]:
plt.figure(figsize=(8,6))

for i, class_name in enumerate(classes):
    # Scores for this class
    y_scores_class = np.array([s if l == i+1 else 0 for s,l in zip(all_scores, all_labels)])
    y_true_class = y_true_bin[:, i]

    precision, recall, _ = precision_recall_curve(y_true_class, y_scores_class)
    ap = average_precision_score(y_true_class, y_scores_class)

    plt.plot(recall, precision, label=f"{class_name} (AP={ap:.2f})")

plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall curve per class")
plt.legend()
plt.show()
[Figure: Precision-Recall curves per class]

The plot did not come out right: it makes the model look perfect, because y_scores_class is zero exactly where y_true_class is zero (the ground-truth labels are reused to gate the scores), so every positive outranks every negative and each AP is 1.0 by construction. I tried several fixes but in the end I gave up X_X

From this point on only failed attempts remain. I am not removing them, so I can keep trying to fix them some day.
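
To see the degeneracy concretely: in the loop above a sample's score is zeroed exactly when its ground-truth label is not the class, so positives always outrank negatives. A tiny check with made-up numbers:

import numpy as np
from sklearn.metrics import average_precision_score

labels_demo = np.array([1, 2, 1, 3])          # hypothetical ground-truth labels
scores_demo = np.array([0.2, 0.9, 0.4, 0.7])  # hypothetical detection scores

# Same construction as above, for class 1 ("buffalo")
y_true_class = (labels_demo == 1).astype(int)
y_scores_class = np.array([s if l == 1 else 0 for s, l in zip(scores_demo, labels_demo)])

# Positives always carry a nonzero score and negatives always 0 -> perfect ranking
print(average_precision_score(y_true_class, y_scores_class))  # 1.0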

In [38]:
for images, targets in valid_loader:
    images = list(img.to(DEVICE) for img in images)

    with torch.no_grad():
        outputs = model(images)

    # Aquí targets es lista de diccionarios
    for output, target in zip(outputs, targets):
        # output es diccionario de predicciones
        scores = output['scores'].cpu().numpy()
        labels_pred = output['labels'].cpu().numpy()

        # target es diccionario de ground-truth
        labels_true = target['labels'].cpu().numpy().flatten()  # <- flatten porque puede ser [[1]]

        all_scores.extend(scores)
        all_labels.extend(labels_true)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipython-input-2908744283.py in <cell line: 0>()
     12 
     13         # target es diccionario de ground-truth
---> 14         labels_true = target['labels'].cpu().numpy().flatten()  # <- flatten porque puede ser [[1]]
     15 
     16         all_scores.extend(scores)

TypeError: string indices must be integers, not 'str'
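
The TypeError above comes from the DataLoader's default collate: it merges each batch's targets into a single dict of stacked tensors, so iterating over targets yields its string keys ('boxes', 'labels', ...) and target['labels'] ends up indexing a string. The usual fix for detection datasets is a collate function that keeps one dict per image; a minimal sketch, reusing the valid_dataset created earlier:

from torch.utils.data import DataLoader

def collate_fn(batch):
    # Keep images and per-image target dicts as parallel tuples instead of stacking
    return tuple(zip(*batch))

valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
# Each iteration now yields (images, targets) as tuples with one entry per image,
# so zip(outputs, targets) pairs one prediction dict with one ground-truth dict.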

I could not get it to work. I also tried the confusion matrix (that did not come out either).

Export to ONNX.

In [ ]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
In [ ]:
!pip install onnx
Collecting onnx
  Downloading onnx-1.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.12/dist-packages (from onnx) (2.0.2)
Requirement already satisfied: protobuf>=4.25.1 in /usr/local/lib/python3.12/dist-packages (from onnx) (5.29.5)
Requirement already satisfied: typing_extensions>=4.7.1 in /usr/local/lib/python3.12/dist-packages (from onnx) (4.14.1)
Downloading onnx-1.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17.6/17.6 MB 110.7 MB/s eta 0:00:00
Installing collected packages: onnx
Successfully installed onnx-1.18.0
In [ ]:
!pip install onnxruntime
Collecting onnxruntime
  Downloading onnxruntime-1.22.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.9 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Requirement already satisfied: flatbuffers in /usr/local/lib/python3.12/dist-packages (from onnxruntime) (25.2.10)
Requirement already satisfied: numpy>=1.21.6 in /usr/local/lib/python3.12/dist-packages (from onnxruntime) (2.0.2)
Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from onnxruntime) (25.0)
Requirement already satisfied: protobuf in /usr/local/lib/python3.12/dist-packages (from onnxruntime) (5.29.5)
Requirement already satisfied: sympy in /usr/local/lib/python3.12/dist-packages (from onnxruntime) (1.13.1)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy->onnxruntime) (1.3.0)
Downloading onnxruntime-1.22.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.5 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16.5/16.5 MB 109.8 MB/s eta 0:00:00
Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 46.0/46.0 kB 4.6 MB/s eta 0:00:00
Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 86.8/86.8 kB 9.8 MB/s eta 0:00:00
Installing collected packages: humanfriendly, coloredlogs, onnxruntime
Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnxruntime-1.22.1
In [ ]:
!python export.py
/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py:4624: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  * torch.tensor(scale_factors[i], dtype=torch.float32)
/usr/local/lib/python3.12/dist-packages/torchvision/ops/boxes.py:166: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
/usr/local/lib/python3.12/dist-packages/torchvision/ops/boxes.py:168: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
/usr/local/lib/python3.12/dist-packages/torchvision/models/detection/transform.py:308: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(s, dtype=torch.float32, device=boxes.device)
/usr/local/lib/python3.12/dist-packages/torchvision/models/detection/transform.py:309: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
Model exported to outputs/retinanet.onnx
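
export.py belongs to the downloaded training code and is not reproduced in this notebook; at its core it is a single torch.onnx.export call. A rough sketch of what it presumably does (the opset version and output names are assumptions; the 640×640 input size and the "images" input name are confirmed by the inference cells below):

import torch

model.eval()
dummy_input = torch.randn(1, 3, 640, 640, device=DEVICE)

torch.onnx.export(
    model,
    dummy_input,
    "outputs/retinanet.onnx",
    input_names=["images"],
    output_names=["boxes", "scores", "labels"],
    opset_version=11,
)
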
In [ ]:
import onnxruntime as ort
import torch
import numpy as np
import cv2

# Path to the ONNX model
onnx_model_path = "outputs/retinanet.onnx"

# Create the ONNX Runtime session
ort_session = ort.InferenceSession(onnx_model_path)

# Prepare a test image
image_path = "/content/datasets/african-wildlife/train/images/1 (106).jpg"
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (640, 640))  # size used in export.py
img = img.astype(np.float32) / 255.0

# HWC -> CHW
img = np.transpose(img, (2, 0, 1))
img = np.expand_dims(img, axis=0)  # add the batch dimension: 1x3x640x640

# Make sure it is float32
img = img.astype(np.float32)

# Run inference
outputs = ort_session.run(None, {"images": img})

# outputs is a list: [boxes, scores, labels]
boxes, scores, labels = outputs
print("Boxes:", boxes.shape)
print("Scores:", scores.shape)
print("Labels:", labels.shape)

# Example: draw the boxes with OpenCV
# (Note: the rectangle is drawn on a temporary transposed copy and the boxes are
# in 640x640 coordinates, so nothing is kept or shown; the next cells fix this.)
for i in range(len(boxes)):
    if scores[i] > 0.5:  # confidence threshold
        x1, y1, x2, y2 = boxes[i].astype(int)
        label = int(labels[i])
        cv2.rectangle(img[0].transpose(1,2,0) * 255, (x1, y1), (x2, y2), (0, 0, 255), 2)
Boxes: (4, 4)
Scores: (4,)
Labels: (4,)
In [ ]:
import onnxruntime as ort
import cv2
import numpy as np

# -----------------------------
# Config
# -----------------------------
onnx_model_path = "outputs/retinanet.onnx"
image_path = "/content/datasets/african-wildlife/train/images/1 (106).jpg"
input_size = 640  # mismo tamaño que usaste al exportar
CLASSES = ["buffalo", "zebra", "elephant", "rhino"]  # reemplaza con tus clases reales

# -----------------------------
# 1️⃣ Cargar la imagen
# -----------------------------
img = cv2.imread(image_path)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_resized = cv2.resize(img_rgb, (input_size, input_size))
img_normalized = img_resized.astype(np.float32) / 255.0

# Convertir a (1, 3, H, W)
input_tensor = np.transpose(img_normalized, (2, 0, 1))[np.newaxis, :, :, :]

# -----------------------------
# 2️⃣ Cargar el modelo ONNX
# -----------------------------
ort_session = ort.InferenceSession(onnx_model_path)

# -----------------------------
# 3️⃣ Realizar inferencia
# -----------------------------
outputs = ort_session.run(None, {"images": input_tensor})

boxes = outputs[0]   # shape: [num_detections, 4]
scores = outputs[1]  # shape: [num_detections]
labels = outputs[2]  # shape: [num_detections]

# -----------------------------
# 4️⃣ Dibujar resultados
# -----------------------------
for box, score, label in zip(boxes, scores, labels):
    if score < 0.3:  # umbral de confianza
        continue

    x1, y1, x2, y2 = box
    x1 = int(x1 * img.shape[1] / input_size)
    x2 = int(x2 * img.shape[1] / input_size)
    y1 = int(y1 * img.shape[0] / input_size)
    y2 = int(y2 * img.shape[0] / input_size)

    class_name = CLASSES[int(label) - 1]  # si tu clase empieza en 1
    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
    cv2.putText(img, f"{class_name}: {score:.2f}", (x1, max(y1-10, 0)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

# Mostrar imagen
cv2.imshow("Detections", img)
cv2.waitKey(0)
cv2.destroyAllWindows()
---------------------------------------------------------------------------
DisabledFunctionError                     Traceback (most recent call last)
/tmp/ipython-input-709664407.py in <cell line: 0>()
     55 
     56 # Mostrar imagen
---> 57 cv2.imshow("Detections", img)
     58 cv2.waitKey(0)
     59 cv2.destroyAllWindows()

/usr/local/lib/python3.12/dist-packages/google/colab/_import_hooks/_cv2.py in wrapped(*args, **kwargs)
     48   def wrapped(*args, **kwargs):
     49     if not os.environ.get(env_var, False):
---> 50       raise DisabledFunctionError(message, name or func.__name__)
     51     return func(*args, **kwargs)
     52 

DisabledFunctionError: cv2.imshow() is disabled in Colab, because it causes Jupyter sessions
to crash; see https://github.com/jupyter/notebook/issues/3935.
As a substitution, consider using
  from google.colab.patches import cv2_imshow
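
As the error message says, the Colab replacement is cv2_imshow; the next cell switches to matplotlib instead, but the one-line fix would be:

from google.colab.patches import cv2_imshow

cv2_imshow(img)  # expects a BGR image, like cv2.imshow, and renders it inline
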
In [ ]:
import onnxruntime as ort
import cv2
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------
# Config
# -----------------------------
onnx_model_path = "outputs/retinanet.onnx"
image_path = "/content/datasets/african-wildlife/train/images/1 (106).jpg"
input_size = 640  # same size used when exporting
CLASSES = ["buffalo", "zebra", "elephant", "rhino"]  # should match the label order used in training (the classes list earlier is buffalo, elephant, rhino, zebra)
CONF_THRESH = 0.3  # confidence threshold

# -----------------------------
# 1️⃣ Load and prepare the image
# -----------------------------
img = cv2.imread(image_path)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_resized = cv2.resize(img_rgb, (input_size, input_size))
img_normalized = img_resized.astype(np.float32) / 255.0

# Convert to (1, 3, H, W) for ONNX
input_tensor = np.transpose(img_normalized, (2, 0, 1))[np.newaxis, :, :, :]

# -----------------------------
# 2️⃣ Load the ONNX model
# -----------------------------
ort_session = ort.InferenceSession(onnx_model_path)

# -----------------------------
# 3️⃣ Inference
# -----------------------------
outputs = ort_session.run(None, {"images": input_tensor})
boxes = outputs[0]   # shape: [num_detections, 4]
scores = outputs[1]  # shape: [num_detections]
labels = outputs[2]  # shape: [num_detections]

# -----------------------------
# 4️⃣ Draw the results
# -----------------------------
img_plot = img_rgb.copy()

for box, score, label in zip(boxes, scores, labels):
    if score < CONF_THRESH:
        continue

    x1, y1, x2, y2 = box
    # Rescale to the original image size
    x1 = int(x1 * img.shape[1] / input_size)
    x2 = int(x2 * img.shape[1] / input_size)
    y1 = int(y1 * img.shape[0] / input_size)
    y2 = int(y2 * img.shape[0] / input_size)

    class_name = CLASSES[int(label) - 1]  # if your class labels start at 1
    cv2.rectangle(img_plot, (x1, y1), (x2, y2), (0, 255, 0), 2)
    cv2.putText(img_plot, f"{class_name}: {score:.2f}", (x1, max(y1-10, 0)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

# -----------------------------
# 5️⃣ Show the image in Jupyter/Colab
# -----------------------------
plt.figure(figsize=(12, 8))
plt.imshow(img_plot)
plt.axis('off')
plt.title("Detections ONNX RetinaNet")
plt.show()
[Figure: ONNX RetinaNet detections drawn on the test image]
In [ ]:
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import torch

metric = MeanAveragePrecision(class_metrics=True)  # nos da AP por clase

# Supongamos que tenemos listas de predicciones y targets como en tu entrenamiento
preds = [...]   # lista de dicts: {"boxes":..., "scores":..., "labels":...}
targets = [...] # lista de dicts: {"boxes":..., "labels":...}

metric.update(preds, targets)
results = metric.compute()
print(results)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipython-input-58285722.py in <cell line: 0>()
      8 targets = [...] # lista de dicts: {"boxes":..., "labels":...}
      9 
---> 10 metric.update(preds, targets)
     11 results = metric.compute()
     12 print(results)

/usr/local/lib/python3.12/dist-packages/torchmetrics/metric.py in wrapped_func(*args, **kwargs)
    547             with torch.set_grad_enabled(self._enable_grad):
    548                 try:
--> 549                     update(*args, **kwargs)
    550                 except RuntimeError as err:
    551                     if "Expected all tensors to be on" in str(err):

/usr/local/lib/python3.12/dist-packages/torchmetrics/detection/mean_ap.py in update(self, preds, target)
    528 
    529         """
--> 530         _input_validator(preds, target, iou_type=self.iou_type)
    531 
    532         for item in preds:

/usr/local/lib/python3.12/dist-packages/torchmetrics/detection/helpers.py in _input_validator(preds, targets, iou_type, ignore_score)
     64 
     65     for k in [*item_val_name, "labels"] + (["scores"] if not ignore_score else []):
---> 66         if any(k not in p for p in preds):
     67             raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key")
     68 

/usr/local/lib/python3.12/dist-packages/torchmetrics/detection/helpers.py in <genexpr>(.0)
     64 
     65     for k in [*item_val_name, "labels"] + (["scores"] if not ignore_score else []):
---> 66         if any(k not in p for p in preds):
     67             raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key")
     68 

TypeError: argument of type 'ellipsis' is not iterable
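
This TypeError is simply because preds and targets were left as literal [...] placeholders. Filling them from the validation loop makes the metric usable; a sketch under the same assumptions as before (a loader that yields one target dict per image, and the trained model on DEVICE):

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision(class_metrics=True)

model.eval()
with torch.no_grad():
    for images, targets in valid_loader:
        images = [img.to(DEVICE) for img in images]
        outputs = model(images)

        # One dict per image, on CPU, with the keys torchmetrics expects
        preds = [{k: v.cpu() for k, v in out.items()} for out in outputs]
        gts = [{"boxes": t["boxes"].reshape(-1, 4).cpu(),
                "labels": t["labels"].reshape(-1).cpu()} for t in targets]
        metric.update(preds, gts)

results = metric.compute()
print(results["map"], results["map_per_class"])
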
In [ ]:
import onnxruntime as ort
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score

# -----------------------------
# Config
# -----------------------------
onnx_model_path = "outputs/retinanet.onnx"
input_size = 640
CLASSES = ["buffalo", "zebra", "elephant", "rhino"]  #clases
CONF_THRESH = 0.0  # para PR curve consideramos todas las detecciones
DEVICE = 'cpu'  # onnxruntime usa CPU o GPU automáticamente

# -----------------------------
# Cargar ONNX
# -----------------------------
ort_session = ort.InferenceSession(onnx_model_path)

# -----------------------------
# Datos de prueba (ejemplo con 1 imagen)
# En la práctica, recorre todo tu conjunto de validación
# -----------------------------
image_path = "/content/datasets/african-wildlife/val/images/1 (102).jpg"
gt_labels = np.array([1, 2])  # Ejemplo: clases reales presentes (cambia según tu label)
gt_boxes = np.array([[50, 60, 200, 300], [150, 200, 350, 400]])  # xyxy

img = cv2.imread(image_path)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_resized = cv2.resize(img_rgb, (input_size, input_size))
img_normalized = img_resized.astype(np.float32) / 255.0
input_tensor = np.transpose(img_normalized, (2, 0, 1))[np.newaxis, :, :, :]

# -----------------------------
# Inferencia ONNX
# -----------------------------
outputs = ort_session.run(None, {"images": input_tensor})
pred_boxes = outputs[0]   # [N,4]
pred_scores = outputs[1]  # [N]
pred_labels = outputs[2]  # [N]

# -----------------------------
# Flatten para PR curve
# Nota: normalmente usamos todas las imágenes y sus etiquetas
# -----------------------------
y_true = []
y_scores = []

for label in CLASSES:
    # Convierte las clases a 0/1 para cada clase
    y_true_class = (gt_labels == (CLASSES.index(label)+1)).astype(int)
    y_scores_class = pred_scores[(pred_labels == (CLASSES.index(label)+1))]
    if len(y_scores_class) == 0:
        y_scores_class = np.array([0]*len(y_true_class))
    y_true.extend(y_true_class.tolist())
    y_scores.extend(y_scores_class.tolist())

# -----------------------------
# Precision-Recall curve
# -----------------------------
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
ap = average_precision_score(y_true, y_scores)

plt.figure(figsize=(8,6))
plt.plot(recall, precision, marker='.', label=f'AP={ap:.3f}')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall curve')
plt.legend()
plt.grid(True)
plt.show()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipython-input-2268590998.py in <cell line: 0>()
     60 # Precision-Recall curve
     61 # -----------------------------
---> 62 precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
     63 ap = average_precision_score(y_true, y_scores)
     64 

/usr/local/lib/python3.12/dist-packages/sklearn/utils/_param_validation.py in wrapper(*args, **kwargs)
    214                     )
    215                 ):
--> 216                     return func(*args, **kwargs)
    217             except InvalidParameterError as e:
    218                 # When the function is just a wrapper around an estimator, we allow

/usr/local/lib/python3.12/dist-packages/sklearn/metrics/_ranking.py in precision_recall_curve(y_true, y_score, pos_label, sample_weight, drop_intermediate, probas_pred)
   1003         y_score = probas_pred
   1004 
-> 1005     fps, tps, thresholds = _binary_clf_curve(
   1006         y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
   1007     )

/usr/local/lib/python3.12/dist-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
    818         raise ValueError("{0} format is not supported".format(y_type))
    819 
--> 820     check_consistent_length(y_true, y_score, sample_weight)
    821     y_true = column_or_1d(y_true)
    822     y_score = column_or_1d(y_score)

/usr/local/lib/python3.12/dist-packages/sklearn/utils/validation.py in check_consistent_length(*arrays)
    473     uniques = np.unique(lengths)
    474     if len(uniques) > 1:
--> 475         raise ValueError(
    476             "Found input variables with inconsistent numbers of samples: %r"
    477             % [int(l) for l in lengths]

ValueError: Found input variables with inconsistent numbers of samples: [8, 6]

It did not work again: precision_recall_curve fails because y_true has one 0/1 entry per ground-truth label for each class (4 × 2 = 8 in total) while y_scores has one entry per kept detection for each class (6 here), so the two lists cannot be paired up. Next I try running the downloaded visualization/inference script.

In [ ]:
!python inf_video.py
usage: inf_video.py [-h] -i INPUT [--imgsz IMGSZ] [--threshold THRESHOLD]
inf_video.py: error: the following arguments are required: -i/--input
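
The script only complains about the missing required argument: it needs -i/--input pointing at a video (the path below is a placeholder), plus the optional imgsz and threshold flags shown in the usage line.

!python inf_video.py -i /content/some_video.mp4 --imgsz 640 --threshold 0.5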