Collate Functions

Some custom collate functions.

source

resize_pad_collate

 resize_pad_collate (batch, max_sz=256)

A custom collate function for a PyTorch DataLoader that performs:

  1. Resize each (image, target) pair so that the image’s maximum dimension does not exceed max_sz (using the custom ResizeMax transform).
  2. Determine the largest image height and width in the batch, round them up to a multiple of 32, and then randomly pad each image so they all share the same dimensions. BoundingBoxes and Masks in the targets are updated accordingly.
  3. Optionally perform a final Resize to ensure all images have the same shape (often a no-op if the padded size already matches).

Args:
    batch (List[Tuple[Image, Dict]]): A list of (image, target) pairs, where:
        - image can be a PIL Image, PyTorch tensor, or TorchVision tv_tensors.Image.
        - target is typically a dictionary containing bounding boxes ("boxes"), masks ("masks"), and possibly other metadata.
    max_sz (int, optional): The maximum size (height or width) for the resize step. Default: 256.

Returns:
    List[Tuple[Image, Dict]]: A list of (image, target) pairs where each image is resized and padded to the same dimensions, and any bounding boxes or masks have been shifted/padded accordingly.

Function: resize_pad_collate(batch, max_sz=256)

A custom collate function designed for use in a PyTorch DataLoader. It processes a batch of (image, target) pairs by:

  1. Resizing each image (and corresponding target data) so that the maximum dimension of the image does not exceed max_sz.
    • Uses the custom ResizeMax transform (from cjm_torchvision_tfms.transforms), which maintains the aspect ratio (see the sketch after this list).
  2. Determining the largest height and width in the batch, rounding them up to the nearest multiple of 32, and randomly padding each image to match those dimensions.
    • Random padding is applied on all sides (top, bottom, left, right) of each image, ensuring that every image ends up with the same final size.
    • Bounding boxes (BoundingBoxes) and masks (Mask) within each target dictionary are adjusted accordingly (shifted and padded).
  3. Performing a final resize to guarantee that all images have the exact same dimensions.
    • Typically a no-op if the padded size already matches, but it guarantees that every image in the batch ends up with an identical shape.
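
As a rough illustration of step 1, the ResizeMax transform can also be applied on its own. The snippet below is a minimal sketch: it assumes ResizeMax is constructed with a max_sz argument, as described above, and that it behaves like a torchvision v2 transform that accepts an image and its target together.

import torch
from torchvision import tv_tensors
from cjm_torchvision_tfms.transforms import ResizeMax  # import path as stated above

# A dummy 3x480x640 image with one bounding box in XYXY format.
img = tv_tensors.Image(torch.rand(3, 480, 640))
target = {
    "boxes": tv_tensors.BoundingBoxes(
        torch.tensor([[10.0, 20.0, 200.0, 300.0]]),
        format="XYXY",
        canvas_size=(480, 640),
    ),
}

# Assuming the constructor takes the maximum allowed side length as max_sz.
resize_max = ResizeMax(max_sz=256)
img, target = resize_max(img, target)

print(img.shape)  # the larger side is now at most 256; aspect ratio preserved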

Arguments

  • batch (List[Tuple[Image, dict]]):
    A list of (image, target) pairs.
    • Each image can be a PIL image, PyTorch tensor image, or a TorchVision tv_tensors.Image.
    • The target is typically a dictionary containing annotation data such as bounding boxes, masks, or other metadata. This function specifically looks for:
      • "boxes" of type tv_tensors.BoundingBoxes
      • "masks" of type tv_tensors.Mask
      • Any other keys in the dictionary are passed through unchanged.
  • max_sz (int, optional, default=256):
    The maximum dimension (width or height) to which each image will be resized.
    • The aspect ratio of each image is preserved while resizing.
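
To make the expected inputs concrete, here is a small sketch of a two-sample batch using the target keys listed above. The shapes, labels, and the helper make_sample are invented for illustration, and resize_pad_collate is assumed to have been imported from this library.

import torch
from torchvision import tv_tensors

def make_sample(h, w):
    """Build a dummy (image, target) pair of size (h, w) for illustration."""
    img = tv_tensors.Image(torch.rand(3, h, w))
    target = {
        "boxes": tv_tensors.BoundingBoxes(
            torch.tensor([[5.0, 5.0, w / 2, h / 2]]),
            format="XYXY",
            canvas_size=(h, w),
        ),
        "masks": tv_tensors.Mask(torch.zeros(1, h, w, dtype=torch.bool)),
        "labels": torch.tensor([3]),  # extra keys are passed through unchanged
    }
    return img, target

batch = [make_sample(480, 640), make_sample(360, 500)]
pairs = resize_pad_collate(batch, max_sz=256)

for img, tgt in pairs:
    print(img.shape)  # every pair now shares the same padded height and width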

Returns

  • final_pairs (List[Tuple[Image, dict]]):
    A list of (image, target) pairs where:
    1. Each image is at most max_sz in its largest dimension (before final padding).
    2. Each image is then padded (randomly on all sides) so that all images share the same height and width (rounded up to multiples of 32).
    3. The bounding boxes and masks in the target (if any) are updated to reflect the padding.
    4. A final resize ensures that each image has the same dimensions, (final_max_height, final_max_width).

How It Works

  1. Resize to max_sz:
    Uses ResizeMax, which shrinks the image so its larger side is at most max_sz, preserving the aspect ratio.

  2. Identify maximum batch dimensions & pad:

    • Loops through all resized images to find the largest height and width.
    • Rounds them up to the nearest multiple of 32.
    • For each image, the required padding is computed, and the amount is split randomly between top/bottom and left/right (see the sketch after this list).
    • Corresponding bounding boxes or masks are updated to match the new image dimensions.
  3. Final resize to enforce consistent shape (if needed):

    • A transforms.Resize((final_max_height, final_max_width)) is applied to each image and target.
    • Often a no-op if the padded size already matches these dimensions, but ensures uniform shape in all images.
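
The rounding and random-split logic in step 2 can be illustrated with a short, standalone sketch. This is not the library's implementation, just a hand-rolled example of rounding the batch maximum up to a multiple of 32 and splitting the required padding randomly between the two sides of each dimension; the helper name pad_to is made up for this example.

import math
import random
import torch
import torch.nn.functional as F

def pad_to(img, out_h, out_w):
    """Randomly pad a CHW tensor on all four sides so it becomes (out_h, out_w)."""
    pad_h = out_h - img.shape[-2]
    pad_w = out_w - img.shape[-1]
    top = random.randint(0, pad_h)   # random split between top and bottom
    left = random.randint(0, pad_w)  # random split between left and right
    # F.pad pads the last two dims in the order (left, right, top, bottom).
    return F.pad(img, (left, pad_w - left, top, pad_h - top)), (top, left)

# Two resized images: 200x240 and 180x256.
imgs = [torch.rand(3, 200, 240), torch.rand(3, 180, 256)]

# Largest height/width in the batch, rounded up to a multiple of 32.
out_h = math.ceil(max(img.shape[-2] for img in imgs) / 32) * 32  # 200 -> 224
out_w = math.ceil(max(img.shape[-1] for img in imgs) / 32) * 32  # 256 -> 256

padded = [pad_to(img, out_h, out_w)[0] for img in imgs]
print([tuple(p.shape) for p in padded])  # [(3, 224, 256), (3, 224, 256)]

# In resize_pad_collate, bounding boxes are shifted by (left, top) and masks are
# padded by the same amounts so they stay aligned with the image.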

Usage Example

Because resize_pad_collate returns a list of (image, target) pairs, you will typically want your dataloader to yield (images, targets) as two separate structures. One way to do this is:

import torch
from torch.utils.data import DataLoader
# resize_pad_collate is defined in this library; import it from its module
# before using it below.

# Suppose you have a dataset that returns (image, target) tuples
my_dataset = ...

train_sz = 256  # Example desired max size

# We wrap our custom function so that the output splits into two lists:
# (images, targets).
collate_fn = lambda batch: tuple(zip(*resize_pad_collate(batch, max_sz=train_sz)))

dataloader = DataLoader(
    my_dataset,
    batch_size=4,
    collate_fn=collate_fn
)

for images, targets in dataloader:
    # 'images' is now a tuple of resized & padded images
    # 'targets' is a tuple of corresponding dictionaries (or other target structures)
    # Each element in 'images' has consistent spatial dimensions
    # You can optionally convert them to a list or a torch.stack:
    # images = list(images)  # or images = torch.stack(images)
    
    # Proceed with training loop ...
    pass
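
Because every image in the batch ends up with identical spatial dimensions, the tuple of images can be stacked into a single batched tensor, assuming the dataset yields tensor images of the same dtype:

for images, targets in dataloader:
    batched_images = torch.stack(list(images))  # shape: (batch_size, C, H, W)
    # 'targets' remains a tuple of per-sample target dictionaries
    # Proceed with training loop ...
    pass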

Notes & Tips

  • Random Padding Benefit:
    Randomly distributing the padding can help reduce positional bias that might arise if padding were always placed in the same region (e.g., always on the right or bottom).

  • Why Round Dimensions to Multiples of 32?
    Many backbones downsample the input repeatedly with stride-2 convolutions or pooling layers; five such stages give an overall stride of 32. Input sizes that are multiples of 32 keep every intermediate feature map at an integer size, avoiding extra internal padding and making output shapes predictable. It can also help with GPU memory management.

  • Handling Non-Dict Targets:
    If your dataset’s target is not a dictionary (or it has a different structure), you’ll need to adapt the function to properly pad and resize those objects.

  • Performance Considerations:

    • These resize and pad operations happen on the CPU. If they become a bottleneck, consider whether you can pre-process the data or move some of these transforms to the GPU.
    • Using random padding every epoch provides a mild data augmentation effect, but also increases CPU workload.

This custom collate function is designed to ensure that each sample in a batch has the same size (height and width), which is typically required for training deep learning models, while properly adjusting any bounding boxes or masks in the target data.