TrivialAugmentWide in PyTorch

Buy Me a Coffee☕ *Memos: My post explains RandAugment() with the num_ops and fill arguments. My post explains AutoAugment(). My post explains AugMix() with no arguments and with all arguments. My post explains OxfordIIITPet(). TrivialAugmentWide() can randomly trivial-augment an image as shown below: *Memos: The 1st argument for initialization is num_magnitude_bins(Optional-Default:31-Type:int). *It must be 1 <= x.

Mar 16, 2025 - 14:37
 0
TrivialAugmentWide in PyTorch

Buy Me a Coffee

*Memos:

TrivialAugmentWide() can randomly trivial-augment an image as shown below:

*Memos:

  • The 1st argument for initialization is num_magnitude_bins(Optional-Default:31-Type:int). *It must be 1 <= x.
  • The 2nd argument for initialization is interpolation(Optional-Default:InterpolationMode.NEAREST-Type:InterpolationMode). *If the input is a tensor, only InterpolationMode.NEAREST or InterpolationMode.BILINEAR can be set.
  • The 3rd argument for initialization is fill(Optional-Default:None-Type:int, float or tuple/list(int or float)): *Memos:
    • It can change the background of an image. *The background can be seen when trivial-augmenting an image.
    • A tuple/list must be 1D with 1 or 3 elements.
    • If all values are x <= 0, it's black.
    • If all values are 255 <= x, it's white.
  • The 1st argument for calling is img(Required-Type:PIL Image or tensor(int)): *Memos:
    • A tensor must be 3D.
    • Don't use img=.
  • Using v2 is recommended, according to "V1 or V2? Which one should I use?".
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import TrivialAugmentWide
from torchvision.transforms.functional import InterpolationMode

# Both assignments build the same transform; the second one just spells out
# the default values of every initialization argument.
taw = TrivialAugmentWide()
taw = TrivialAugmentWide(num_magnitude_bins=31,
                         interpolation=InterpolationMode.NEAREST,
                         fill=None)
taw
# TrivialAugmentWide(interpolation=InterpolationMode.NEAREST,
#                    num_magnitude_bins=31)

taw.num_magnitude_bins
# 31

taw.interpolation
# <InterpolationMode.NEAREST: 'nearest'>

print(taw.fill)
# None

def _oxford_pets(transform=None):
    """Return the OxfordIIITPet dataset stored under ``data/``.

    Args:
        transform: transform applied to each image when it is loaded
            (``None`` keeps the original images).
    """
    return OxfordIIITPet(root="data", transform=transform)


# The untouched images, used as the reference row.
origin_data = _oxford_pets(None)

# `nmb` is num_magnitude_bins.
nmb1_data = _oxford_pets(TrivialAugmentWide(num_magnitude_bins=1))
nmb2_data = _oxford_pets(TrivialAugmentWide(num_magnitude_bins=2))
nmb5_data = _oxford_pets(TrivialAugmentWide(num_magnitude_bins=5))
nmb10_data = _oxford_pets(TrivialAugmentWide(num_magnitude_bins=10))
nmb25_data = _oxford_pets(TrivialAugmentWide(num_magnitude_bins=25))
nmb50_data = _oxford_pets(TrivialAugmentWide(num_magnitude_bins=50))
nmb100_data = _oxford_pets(TrivialAugmentWide(num_magnitude_bins=100))
nmb500_data = _oxford_pets(TrivialAugmentWide(num_magnitude_bins=500))
nmb1000_data = _oxford_pets(TrivialAugmentWide(num_magnitude_bins=1000))

# `f` is fill.
nmb10fgray_data = _oxford_pets(
    TrivialAugmentWide(num_magnitude_bins=10, fill=150)
    # TrivialAugmentWide(num_magnitude_bins=10, fill=[150])
)
nmb10fpurple_data = _oxford_pets(
    TrivialAugmentWide(num_magnitude_bins=10, fill=[160, 32, 240])
)

import matplotlib.pyplot as plt

def show_images1(data, main_title=None):
    """Show the first five images of ``data`` side by side in one figure.

    ``data`` is iterated as ``(image, label)`` pairs. Each image is drawn
    as-is: any augmentation must already be part of the dataset's own
    transform. The labels are ignored.

    Args:
        data: iterable of ``(image, label)`` pairs, e.g. an ``OxfordIIITPet``.
        main_title: optional figure title.
    """
    columns = 5
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    # zip() pulls from the range first, so once five panels are filled it
    # stops without fetching (and transforming) a sixth sample from `data`.
    panels = zip(range(1, columns + 1), data)
    for column, (picture, _label) in panels:
        plt.subplot(1, columns, column)
        plt.imshow(X=picture)
        # Hide the tick marks; only the pictures matter here.
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

show_images1(data=origin_data, main_title="origin_data")
print()
# One figure per num_magnitude_bins setting.
for _title, _ds in (("nmb1_data", nmb1_data),
                    ("nmb2_data", nmb2_data),
                    ("nmb5_data", nmb5_data),
                    ("nmb10_data", nmb10_data),
                    ("nmb25_data", nmb25_data),
                    ("nmb50_data", nmb50_data),
                    ("nmb100_data", nmb100_data),
                    ("nmb500_data", nmb500_data),
                    ("nmb1000_data", nmb1000_data)):
    show_images1(data=_ds, main_title=_title)
print()
# One figure per fill colour.
for _title, _ds in (("nmb10fgray_data", nmb10fgray_data),
                    ("nmb10fpurple_data", nmb10fpurple_data)):
    show_images1(data=_ds, main_title=_title)
# Drop the loop variables so they do not linger at module level.
del _title, _ds

# ↓ ↓ ↓ ↓ ↓ ↓ The code below is equivalent to the code above (it applies TrivialAugmentWide while plotting instead of in the dataset). ↓ ↓ ↓ ↓ ↓ ↓
def show_images2(data, main_title=None, nmb=31,
                 ip=InterpolationMode.NEAREST, f=None):
    """Show the first five images of ``data``, trivial-augmenting them on the fly.

    Args:
        data: iterable of ``(image, label)`` pairs, e.g. an ``OxfordIIITPet``
            dataset built with ``transform=None``. Labels are ignored.
        main_title: figure title. The special value ``"origin_data"`` turns
            augmentation off, so the images are drawn unchanged and
            ``nmb``/``ip``/``f`` are ignored.
        nmb: ``num_magnitude_bins`` passed to ``TrivialAugmentWide``.
        ip: ``interpolation`` passed to ``TrivialAugmentWide``.
        f: ``fill`` passed to ``TrivialAugmentWide``.
    """
    # Build the transform once, outside the loop: TrivialAugmentWide samples a
    # new random op and magnitude on every call, so one shared instance gives
    # the same results as constructing a fresh one for each image.
    if main_title != "origin_data":
        taw = TrivialAugmentWide(num_magnitude_bins=nmb,
                                 interpolation=ip, fill=f)
    else:
        taw = None

    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    for i, (im, _) in zip(range(1, 6), data):
        plt.subplot(1, 5, i)
        plt.imshow(X=im if taw is None else taw(im))
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

show_images2(data=origin_data, main_title="origin_data")
print()
# Same images, augmented while plotting with increasing num_magnitude_bins.
for _title, _bins in (("nmb1_data", 1),
                      ("nmb2_data", 2),
                      ("nmb5_data", 5),
                      ("nmb10_data", 10),
                      ("nmb25_data", 25),
                      ("nmb50_data", 50),
                      ("nmb100_data", 100),
                      ("nmb500_data", 500),
                      ("nmb1000_data", 1000)):
    show_images2(data=origin_data, main_title=_title, nmb=_bins)
print()
# Fill colour examples (gray, then purple).
show_images2(data=origin_data, main_title="nmb10fgray_data", nmb=10, f=150)
show_images2(data=origin_data, main_title="nmb10fpurple_data", nmb=10,
             f=[160, 32, 240])
# Drop the loop variables so they do not linger at module level.
del _title, _bins

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description