ColorJitter in PyTorch (2)

Buy Me a Coffee☕ *Memos: My post explains ColorJitter() about brightness argument. My post explains ColorJitter() about saturation argument. My post explains ColorJitter() about hue argument. My post explains OxfordIIITPet(). ColorJitter() can randomly change the brightness, contrast, saturation and hue of an image as shown below: from torchvision.datasets import OxfordIIITPet from torchvision.transforms.v2 import ColorJitter origin_data = OxfordIIITPet( root="data", transform=None ) contrast1_1origin_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[1, 1]) ) contrast0_5_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0, 5]) # transform=ColorJitter(contrast=4) ) contrast0_1_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0, 1]) ) contrast1_5_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[1, 5]) ) contrast08_08_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0.8, 0.8]) ) contrast06_06_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0.6, 0.6]) ) contrast04_04_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0.4, 0.4]) ) contrast02_02_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0.2, 0.2]) ) contrast0_0_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0, 0]) ) contrast2_2_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[2, 2]) ) contrast4_4_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[4, 4]) ) contrast8_8_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[8, 8]) ) contrast16_16_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[16, 16]) ) contrast50_50_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[50, 50]) ) import matplotlib.pyplot as plt def show_images1(data, main_title=None): plt.figure(figsize=[10, 5]) plt.suptitle(t=main_title, y=0.8, fontsize=14) for i, (im, _) in zip(range(1, 6), data): plt.subplot(1, 
5, i) plt.imshow(X=im) plt.xticks(ticks=[]) plt.yticks(ticks=[]) plt.tight_layout() plt.show() show_images1(data=contrast1_1origin_data, main_title="contrast1_1origin_data") show_images1(data=contrast0_5_data, main_title="contrast0_5_data") show_images1(data=contrast0_1_data, main_title="contrast0_1_data") show_images1(data=contrast1_5_data, main_title="contrast1_5_data") print() show_images1(data=contrast1_1origin_data, main_title="contrast1_1origin_data") show_images1(data=contrast08_08_data, main_title="contrast08_08_data") show_images1(data=contrast06_06_data, main_title="contrast06_06_data") show_images1(data=contrast04_04_data, main_title="contrast04_04_data") show_images1(data=contrast02_02_data, main_title="contrast02_02_data") show_images1(data=contrast0_0_data, main_title="contrast0_0_data") print() show_images1(data=contrast1_1origin_data, main_title="contrast1_1origin_data") show_images1(data=contrast2_2_data, main_title="contrast2_2_data") show_images1(data=contrast4_4_data, main_title="contrast4_4_data") show_images1(data=contrast8_8_data, main_title="contrast8_8_data") show_images1(data=contrast16_16_data, main_title="contrast16_16_data") show_images1(data=contrast50_50_data, main_title="contrast50_50_data") # ↓ ↓ ↓ ↓ ↓ ↓ The code below is identical to the code above. 
↓ ↓ ↓ ↓ ↓ ↓ def show_images2(data, main_title=None, b=0, c=0, s=0, h=0): plt.figure(figsize=[10, 5]) plt.suptitle(t=main_title, y=0.8, fontsize=14) for i, (im, _) in zip(range(1, 6), data): plt.subplot(1, 5, i) cj = ColorJitter(brightness=b, contrast=c, # Here saturation=s, hue=h) plt.imshow(X=cj(im)) # Here plt.xticks(ticks=[]) plt.yticks(ticks=[]) plt.tight_layout() plt.show() show_images2(data=origin_data, main_title="contrast1_1origin_data", c=[1, 1]) show_images2(data=origin_data, main_title="contrast0_5_data", c=[0, 5]) # ↑ show_images2(data=origin_data, main_title="contrast4_data", c=4) show_images2(data=origin_data, main_title="contrast0_1_data", c=[0, 1]) show_images2(data=origin_data, main_title="contrast1_5_data", c=[1, 5]) print() show_images2(data=origin_data, main_title="contrast1_1origin_data", c=[1, 1]) show_images2(data=origin_data, main_title="contrast08_08_data", c=[0.8, 0.8]) show_images2(data=origin_data, main_title="contrast06_06_data", c=[0.6, 0.6]) show_images2(data=origin_data, main_title="contrast04_04_data", c=[0.4, 0.4]) show_images2(data=origin_data, main_title="contrast02_02_data", c=[0.2, 0.2]) show_images2(data=origin_data, main_title="contrast0_0_data", c=[0, 0]) print() show_images2(data=origin_data, main_title="contrast1_1origin_data", c=[1, 1]) show_images2(data=origin_data, main_title="contrast2_2_data", c

Feb 18, 2025 - 12:12
 0
ColorJitter in PyTorch (2)

Buy Me a Coffee

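*Memos:

- My post explains ColorJitter() about the brightness argument.
- My post explains ColorJitter() about the saturation argument.
- My post explains ColorJitter() about the hue argument.
- My post explains OxfordIIITPet().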

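ColorJitter() can randomly change the brightness, contrast, saturation and hue of an image. This post covers the `contrast` argument:

- A contrast factor is sampled uniformly from the given `[min, max]` range for each image.
- A single number `c` means `[max(0, 1 - c), 1 + c]`, so `contrast=4` is the same as `contrast=[0, 5]`.
- A factor of 1 keeps the original image, 0 gives a solid gray image, and larger values increase the contrast.

This is shown below: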

from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import ColorJitter

def _contrast_pet_data(low, high):
    """Return an OxfordIIITPet dataset jittered with ``contrast=[low, high]``.

    Every image read from the returned dataset passes through
    ``ColorJitter(contrast=[low, high])``; when ``low == high`` the contrast
    factor is fixed, otherwise it is drawn from that range per image.
    """
    return OxfordIIITPet(
        root="data",
        transform=ColorJitter(contrast=[low, high])
    )


# Reference images with no transform at all.
origin_data = OxfordIIITPet(
    root="data",
    transform=None
)

# A contrast factor of exactly 1 leaves the image unchanged.
contrast1_1origin_data = _contrast_pet_data(1, 1)

# Random ranges. A single number c is shorthand for [max(0, 1 - c), 1 + c],
# so ColorJitter(contrast=4) would behave the same as contrast=[0, 5].
contrast0_5_data = _contrast_pet_data(0, 5)
contrast0_1_data = _contrast_pet_data(0, 1)
contrast1_5_data = _contrast_pet_data(1, 5)

# Fixed factors below 1: progressively lower contrast (0 is solid gray).
contrast08_08_data = _contrast_pet_data(0.8, 0.8)
contrast06_06_data = _contrast_pet_data(0.6, 0.6)
contrast04_04_data = _contrast_pet_data(0.4, 0.4)
contrast02_02_data = _contrast_pet_data(0.2, 0.2)
contrast0_0_data = _contrast_pet_data(0, 0)

# Fixed factors above 1: progressively higher contrast.
contrast2_2_data = _contrast_pet_data(2, 2)
contrast4_4_data = _contrast_pet_data(4, 4)
contrast8_8_data = _contrast_pet_data(8, 8)
contrast16_16_data = _contrast_pet_data(16, 16)
contrast50_50_data = _contrast_pet_data(50, 50)

import matplotlib.pyplot as plt

def show_images1(data, main_title=None):
    """Draw up to the first five images of *data* side by side in one figure.

    Args:
        data: Iterable of ``(image, label)`` pairs, e.g. a torchvision
            dataset; only the images are drawn, labels are ignored.
        main_title: Optional title placed above the row of images.
    """
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    for slot, (picture, _label) in enumerate(data, start=1):
        plt.subplot(1, 5, slot)
        plt.imshow(X=picture)
        # Hide axis ticks: only the pixels matter for the comparison.
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
        if slot == 5:
            # Stop after the fifth panel without pulling another sample.
            break
    plt.tight_layout()
    plt.show()

# Each group of contrast settings is preceded by the unchanged reference row
# (contrast [1, 1]) for side-by-side comparison; an empty print() separates
# consecutive groups.
_show1_groups = (
    (
        (contrast0_5_data, "contrast0_5_data"),
        (contrast0_1_data, "contrast0_1_data"),
        (contrast1_5_data, "contrast1_5_data"),
    ),
    (
        (contrast08_08_data, "contrast08_08_data"),
        (contrast06_06_data, "contrast06_06_data"),
        (contrast04_04_data, "contrast04_04_data"),
        (contrast02_02_data, "contrast02_02_data"),
        (contrast0_0_data, "contrast0_0_data"),
    ),
    (
        (contrast2_2_data, "contrast2_2_data"),
        (contrast4_4_data, "contrast4_4_data"),
        (contrast8_8_data, "contrast8_8_data"),
        (contrast16_16_data, "contrast16_16_data"),
        (contrast50_50_data, "contrast50_50_data"),
    ),
)

for _group_no, _group in enumerate(_show1_groups):
    if _group_no > 0:
        print()
    show_images1(data=contrast1_1origin_data,
                 main_title="contrast1_1origin_data")
    for _dataset, _title in _group:
        show_images1(data=_dataset, main_title=_title)

# ↓ ↓ ↓ ↓ ↓ ↓ The code below is equivalent to the code above: it applies ColorJitter inside the plotting function instead of as a dataset transform. ↓ ↓ ↓ ↓ ↓ ↓
def show_images2(data, main_title=None, b=0, c=0, s=0, h=0):
    """Jitter and draw up to the first five images of *data* in one row.

    Instead of relying on a dataset-level ``transform``, the ColorJitter is
    built here from the arguments. This lets a single untransformed dataset
    be reused for every setting.

    Args:
        data: Iterable of ``(image, label)`` pairs, e.g. a torchvision
            dataset created with ``transform=None``; labels are ignored.
        main_title: Optional title placed above the row of images.
        b: Passed to ColorJitter as ``brightness``.
        c: Passed to ColorJitter as ``contrast``.
        s: Passed to ColorJitter as ``saturation``.
        h: Passed to ColorJitter as ``hue``.
    """
    # The jitter arguments never change inside the loop, so build the
    # transform once rather than once per image. ColorJitter samples new
    # random factors on every call, so each image is still jittered
    # independently. Building it before plt.figure() also means invalid
    # arguments fail before an empty figure is created.
    cj = ColorJitter(brightness=b, contrast=c, saturation=s, hue=h)
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    for i, (im, _) in zip(range(1, 6), data):
        plt.subplot(1, 5, i)
        plt.imshow(X=cj(im))
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

# Same comparison as with show_images1, but jittering origin_data on the fly.
# Each group is preceded by the unchanged reference row (c=[1, 1]); an empty
# print() separates consecutive groups. Note that c=4 would behave the same
# as c=[0, 5], since a single number c means [max(0, 1 - c), 1 + c].
_show2_groups = (
    (
        ("contrast0_5_data", [0, 5]),
        ("contrast0_1_data", [0, 1]),
        ("contrast1_5_data", [1, 5]),
    ),
    (
        ("contrast08_08_data", [0.8, 0.8]),
        ("contrast06_06_data", [0.6, 0.6]),
        ("contrast04_04_data", [0.4, 0.4]),
        ("contrast02_02_data", [0.2, 0.2]),
        ("contrast0_0_data", [0, 0]),
    ),
    (
        ("contrast2_2_data", [2, 2]),
        ("contrast4_4_data", [4, 4]),
        ("contrast8_8_data", [8, 8]),
        ("contrast16_16_data", [16, 16]),
        ("contrast50_50_data", [50, 50]),
    ),
)

for _row_no, _rows in enumerate(_show2_groups):
    if _row_no:
        print()
    show_images2(data=origin_data, main_title="contrast1_1origin_data",
                 c=[1, 1])
    for _label, _contrast in _rows:
        show_images2(data=origin_data, main_title=_label, c=_contrast)

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description

Image description