cjm-pytorch-utils

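Utility functions for working with PyTorch. The package covers setting a random seed, converting between PIL images and tensors, iterating over a model's modules, summarizing tensor statistics, selecting a torch device, and denormalizing image tensors.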

Install

pip install cjm_pytorch_utils

How to use

set_seed

from cjm_pytorch_utils.core import set_seed
seed = 1234
set_seed(seed)

pil_to_tensor

from cjm_pytorch_utils.core import pil_to_tensor
from PIL import Image
from torchvision import transforms
img_path = '../images/cat.jpg'
src_img = Image.open(img_path).convert('RGB')
print(f"Source Image Size: {src_img.size}")

img_tensor = pil_to_tensor(src_img, [0.5], [0.5])
img_tensor.shape, img_tensor.min(), img_tensor.max()
Source Image Size: (768, 512)
(torch.Size([1, 3, 512, 768]), tensor(-1.), tensor(1.))

tensor_to_pil

from cjm_pytorch_utils.core import tensor_to_pil
tensor_img = tensor_to_pil(transforms.ToTensor()(src_img))
tensor_img

iterate_modules

from cjm_pytorch_utils.core import iterate_modules
import torch
from torchvision import models
vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features

for index, module in enumerate(iterate_modules(vgg)):
    if type(module) == torch.nn.modules.activation.ReLU:
        print(f"{index}: {module}")
1: ReLU(inplace=True)
3: ReLU(inplace=True)
6: ReLU(inplace=True)
8: ReLU(inplace=True)
11: ReLU(inplace=True)
13: ReLU(inplace=True)
15: ReLU(inplace=True)
18: ReLU(inplace=True)
20: ReLU(inplace=True)
22: ReLU(inplace=True)
25: ReLU(inplace=True)
27: ReLU(inplace=True)
29: ReLU(inplace=True)

tensor_stats_df

from cjm_pytorch_utils.core import tensor_stats_df
tensor_stats_df(torch.randn(1, 3, 256, 256))
0
mean 0.003342
std 0.99868
min -4.558271
max 4.815985
shape (1, 3, 256, 256)

get_torch_device

from cjm_pytorch_utils.core import get_torch_device
get_torch_device()
'cuda'

denorm_img_tensor

from cjm_pytorch_utils.core import denorm_img_tensor
tensor_to_pil(img_tensor)

tensor_to_pil(denorm_img_tensor(img_tensor, [0.5], [0.5]))