ColorJitter in PyTorch (2)
Buy Me a Coffee☕ *Memos: My post explains ColorJitter() about brightness argument. My post explains ColorJitter() about saturation argument. My post explains ColorJitter() about hue argument. My post explains OxfordIIITPet(). ColorJitter() can randomly change the brightness, contrast, saturation and hue of an image as shown below: from torchvision.datasets import OxfordIIITPet from torchvision.transforms.v2 import ColorJitter origin_data = OxfordIIITPet( root="data", transform=None ) contrast1_1origin_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[1, 1]) ) contrast0_5_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0, 5]) # transform=ColorJitter(contrast=4) ) contrast0_1_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0, 1]) ) contrast1_5_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[1, 5]) ) contrast08_08_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0.8, 0.8]) ) contrast06_06_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0.6, 0.6]) ) contrast04_04_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0.4, 0.4]) ) contrast02_02_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0.2, 0.2]) ) contrast0_0_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[0, 0]) ) contrast2_2_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[2, 2]) ) contrast4_4_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[4, 4]) ) contrast8_8_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[8, 8]) ) contrast16_16_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[16, 16]) ) contrast50_50_data = OxfordIIITPet( root="data", transform=ColorJitter(contrast=[50, 50]) ) import matplotlib.pyplot as plt def show_images1(data, main_title=None): plt.figure(figsize=[10, 5]) plt.suptitle(t=main_title, y=0.8, fontsize=14) for i, (im, _) in zip(range(1, 6), data): plt.subplot(1, 
5, i) plt.imshow(X=im) plt.xticks(ticks=[]) plt.yticks(ticks=[]) plt.tight_layout() plt.show() show_images1(data=contrast1_1origin_data, main_title="contrast1_1origin_data") show_images1(data=contrast0_5_data, main_title="contrast0_5_data") show_images1(data=contrast0_1_data, main_title="contrast0_1_data") show_images1(data=contrast1_5_data, main_title="contrast1_5_data") print() show_images1(data=contrast1_1origin_data, main_title="contrast1_1origin_data") show_images1(data=contrast08_08_data, main_title="contrast08_08_data") show_images1(data=contrast06_06_data, main_title="contrast06_06_data") show_images1(data=contrast04_04_data, main_title="contrast04_04_data") show_images1(data=contrast02_02_data, main_title="contrast02_02_data") show_images1(data=contrast0_0_data, main_title="contrast0_0_data") print() show_images1(data=contrast1_1origin_data, main_title="contrast1_1origin_data") show_images1(data=contrast2_2_data, main_title="contrast2_2_data") show_images1(data=contrast4_4_data, main_title="contrast4_4_data") show_images1(data=contrast8_8_data, main_title="contrast8_8_data") show_images1(data=contrast16_16_data, main_title="contrast16_16_data") show_images1(data=contrast50_50_data, main_title="contrast50_50_data") # ↓ ↓ ↓ ↓ ↓ ↓ The code below is identical to the code above. 
↓ ↓ ↓ ↓ ↓ ↓ def show_images2(data, main_title=None, b=0, c=0, s=0, h=0): plt.figure(figsize=[10, 5]) plt.suptitle(t=main_title, y=0.8, fontsize=14) for i, (im, _) in zip(range(1, 6), data): plt.subplot(1, 5, i) cj = ColorJitter(brightness=b, contrast=c, # Here saturation=s, hue=h) plt.imshow(X=cj(im)) # Here plt.xticks(ticks=[]) plt.yticks(ticks=[]) plt.tight_layout() plt.show() show_images2(data=origin_data, main_title="contrast1_1origin_data", c=[1, 1]) show_images2(data=origin_data, main_title="contrast0_5_data", c=[0, 5]) # ↑ show_images2(data=origin_data, main_title="contrast4_data", c=4) show_images2(data=origin_data, main_title="contrast0_1_data", c=[0, 1]) show_images2(data=origin_data, main_title="contrast1_5_data", c=[1, 5]) print() show_images2(data=origin_data, main_title="contrast1_1origin_data", c=[1, 1]) show_images2(data=origin_data, main_title="contrast08_08_data", c=[0.8, 0.8]) show_images2(data=origin_data, main_title="contrast06_06_data", c=[0.6, 0.6]) show_images2(data=origin_data, main_title="contrast04_04_data", c=[0.4, 0.4]) show_images2(data=origin_data, main_title="contrast02_02_data", c=[0.2, 0.2]) show_images2(data=origin_data, main_title="contrast0_0_data", c=[0, 0]) print() show_images2(data=origin_data, main_title="contrast1_1origin_data", c=[1, 1]) show_images2(data=origin_data, main_title="contrast2_2_data", c

*Memos:
- My post explains the brightness argument of ColorJitter().
- My post explains the saturation argument of ColorJitter().
- My post explains the hue argument of ColorJitter().
- My post explains OxfordIIITPet().
ColorJitter() can randomly change the brightness, contrast, saturation and hue of an image as shown below:
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import ColorJitter
def _contrast_dataset(low, high):
    """Return an OxfordIIITPet dataset transformed by ColorJitter(contrast=[low, high])."""
    return OxfordIIITPet(
        root="data",
        transform=ColorJitter(contrast=[low, high])
    )


# Dataset without any transform; used when the jitter is applied at display
# time instead of at load time.
origin_data = OxfordIIITPet(root="data", transform=None)

# Ranges whose two bounds differ: the contrast factor is picked from [low, high].
contrast1_1origin_data = _contrast_dataset(1, 1)
# contrast=[0, 5] can also be written as ColorJitter(contrast=4).
contrast0_5_data = _contrast_dataset(0, 5)
contrast0_1_data = _contrast_dataset(0, 1)
contrast1_5_data = _contrast_dataset(1, 5)

# Single-value ranges (low == high) at or below 1.
contrast08_08_data = _contrast_dataset(0.8, 0.8)
contrast06_06_data = _contrast_dataset(0.6, 0.6)
contrast04_04_data = _contrast_dataset(0.4, 0.4)
contrast02_02_data = _contrast_dataset(0.2, 0.2)
contrast0_0_data = _contrast_dataset(0, 0)

# Single-value ranges (low == high) above 1.
contrast2_2_data = _contrast_dataset(2, 2)
contrast4_4_data = _contrast_dataset(4, 4)
contrast8_8_data = _contrast_dataset(8, 8)
contrast16_16_data = _contrast_dataset(16, 16)
contrast50_50_data = _contrast_dataset(50, 50)
import matplotlib.pyplot as plt
def show_images1(data, main_title=None):
    """Draw up to five images from ``data`` in a single row and show the figure.

    Args:
        data: Iterable of 2-item samples. The first item of each sample is
            drawn with ``plt.imshow``; the second item is ignored.
        main_title: Text used as the figure-level title.
    """
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    # ``range`` is the first argument of ``zip`` so iteration stops after five
    # samples without pulling a sixth one out of ``data``.
    for column, sample in zip(range(1, 6), data):
        image, _ = sample
        plt.subplot(1, 5, column)
        plt.imshow(X=image)
        # Empty tick lists hide the axis ticks around each image.
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()
# Figures are shown batch by batch; every batch starts with the
# contrast=[1, 1] dataset so the other figures can be compared against it.
_show1_batches = (
    (
        (contrast1_1origin_data, "contrast1_1origin_data"),
        (contrast0_5_data, "contrast0_5_data"),
        (contrast0_1_data, "contrast0_1_data"),
        (contrast1_5_data, "contrast1_5_data"),
    ),
    (
        (contrast1_1origin_data, "contrast1_1origin_data"),
        (contrast08_08_data, "contrast08_08_data"),
        (contrast06_06_data, "contrast06_06_data"),
        (contrast04_04_data, "contrast04_04_data"),
        (contrast02_02_data, "contrast02_02_data"),
        (contrast0_0_data, "contrast0_0_data"),
    ),
    (
        (contrast1_1origin_data, "contrast1_1origin_data"),
        (contrast2_2_data, "contrast2_2_data"),
        (contrast4_4_data, "contrast4_4_data"),
        (contrast8_8_data, "contrast8_8_data"),
        (contrast16_16_data, "contrast16_16_data"),
        (contrast50_50_data, "contrast50_50_data"),
    ),
)
for _show1_batch_no, _show1_batch in enumerate(_show1_batches):
    if _show1_batch_no:
        # Blank line on stdout between consecutive batches.
        print()
    for _show1_dataset, _show1_title in _show1_batch:
        show_images1(data=_show1_dataset, main_title=_show1_title)
# ↓ ↓ ↓ ↓ ↓ ↓ The code below is equivalent to the code above. ↓ ↓ ↓ ↓ ↓ ↓
def show_images2(data, main_title=None, b=0, c=0, s=0, h=0):
    """Draw up to five images from ``data`` after passing each through ColorJitter.

    Unlike a dataset-level transform, the jitter is applied here, at display
    time, so ``data`` can be an untransformed dataset.

    Args:
        data: Iterable of 2-item samples. The first item of each sample is
            transformed and drawn; the second item is ignored.
        main_title: Text used as the figure-level title.
        b: Forwarded to ``ColorJitter`` as ``brightness``.
        c: Forwarded to ``ColorJitter`` as ``contrast`` (a number or a
            ``[min, max]`` pair).
        s: Forwarded to ``ColorJitter`` as ``saturation``.
        h: Forwarded to ``ColorJitter`` as ``hue``.
    """
    # The transform depends only on the arguments, so it is built once rather
    # than once per image; each call to ``cj`` still picks its own factors.
    cj = ColorJitter(brightness=b, contrast=c,
                     saturation=s, hue=h)
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    # ``range`` is the first argument of ``zip`` so iteration stops after five
    # samples without pulling a sixth one out of ``data``.
    for i, (im, _) in zip(range(1, 6), data):
        plt.subplot(1, 5, i)
        plt.imshow(X=cj(im))
        # Empty tick lists hide the axis ticks around each image.
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()
# Each entry is (figure title, contrast argument). Every figure is drawn from
# the untransformed ``origin_data``; the jitter is applied by show_images2().
_show2_batches = (
    (
        ("contrast1_1origin_data", [1, 1]),
        # c=[0, 5] can also be written as c=4 (title "contrast4_data").
        ("contrast0_5_data", [0, 5]),
        ("contrast0_1_data", [0, 1]),
        ("contrast1_5_data", [1, 5]),
    ),
    (
        ("contrast1_1origin_data", [1, 1]),
        ("contrast08_08_data", [0.8, 0.8]),
        ("contrast06_06_data", [0.6, 0.6]),
        ("contrast04_04_data", [0.4, 0.4]),
        ("contrast02_02_data", [0.2, 0.2]),
        ("contrast0_0_data", [0, 0]),
    ),
    (
        ("contrast1_1origin_data", [1, 1]),
        ("contrast2_2_data", [2, 2]),
        ("contrast4_4_data", [4, 4]),
        ("contrast8_8_data", [8, 8]),
        ("contrast16_16_data", [16, 16]),
        ("contrast50_50_data", [50, 50]),
    ),
)
for _show2_batch_no, _show2_batch in enumerate(_show2_batches):
    if _show2_batch_no:
        # Blank line on stdout between consecutive batches.
        print()
    for _show2_title, _show2_contrast in _show2_batch:
        show_images2(data=origin_data, main_title=_show2_title,
                     c=_show2_contrast)