Building Region Detection - DeepLabV3+
2024. 4. 16. 19:53 · Python
0. Environment Setup
!pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
!pip install opencv-python
!pip install pandas
!pip install random2
!pip install seaborn
!pip install "numpy>=1.20.3"
!pip install albumentations
!python3 -m pip install --user albumentations
!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
!pip install timm==0.9.5
!pip install pretrainedmodels==0.7.4
!git clone https://github.com/qubvel/segmentation_models.pytorch.git
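Before unpacking the data, a quick sanity check can confirm that the pinned builds actually import and that the GPU is visible. This is only an optional sketch; the exact version strings depend on the runtime.
import torch, torchvision, albumentations, segmentation_models_pytorch
print('torch         :', torch.__version__)
print('torchvision   :', torchvision.__version__)
print('albumentations:', albumentations.__version__)
print('smp           :', segmentation_models_pytorch.__version__)
print('CUDA available:', torch.cuda.is_available())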
1. Data Preparation
import os
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
!unzip /content/drive/MyDrive/archive.zip -d /content/segmentation_models.pytorch/data
DATA_DIR = '/content/segmentation_models.pytorch/data/png'
x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'val')
y_valid_dir = os.path.join(DATA_DIR, 'val_labels')
x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'test_labels')
class_names = ['background', 'building'] # select_classes
class_idx = [0,1] # select_class_indices
class_rgb_values = [[0,0,0], [255,255,255]] # select_class_rgb_values
# Sort the listings so the sample image and its label file actually correspond.
x_sample = cv2.imread(os.path.join(x_train_dir, sorted(os.listdir(x_train_dir))[0]))
y_sample = cv2.imread(os.path.join(y_train_dir, sorted(os.listdir(y_train_dir))[0]))
plt.figure()
plt.subplot(1,2,1)
plt.title('RGB Image')
plt.imshow(cv2.cvtColor(x_sample, cv2.COLOR_BGR2RGB))
plt.subplot(1,2,2)
plt.title('Mask')
plt.imshow(cv2.cvtColor(y_sample, cv2.COLOR_BGR2RGB))
plt.show()
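It can also help to confirm that the label PNGs really contain only the two colours listed in class_rgb_values. The check below is a small optional sketch that assumes the masks are plain black/white images like the sample above.
import numpy as np
unique_colours = np.unique(y_sample.reshape(-1, 3), axis=0)
print(unique_colours)  # expected: only [0, 0, 0] and [255, 255, 255]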
2. Defining the Classes and Functions Needed for This Exercise
%cd /content/segmentation_models.pytorch
import os, cv2
import numpy as np
import pandas as pd
import random
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils as su
def visualize(**images):
    # Plot the given images side by side, using the keyword names as titles.
    n_images = len(images)
    plt.figure(figsize=(20, 8))
    for idx, (name, image) in enumerate(images.items()):
        plt.subplot(1, n_images, idx + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title(), fontsize=20)
        plt.imshow(image)
    plt.show()
def one_hot_encode(label, label_values):
    # Convert an RGB mask (H, W, 3) into a one-hot map (H, W, n_classes).
    semantic_map = []
    for colour in label_values:
        equality = np.equal(label, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1)
    return semantic_map
def reverse_one_hot(image):
    # Collapse a one-hot map (H, W, n_classes) back to class indices (H, W).
    return np.argmax(image, axis=-1)
def colour_code_segmentation(image, label_values):
    # Map class indices (H, W) back to their RGB colours (H, W, 3).
    colour_codes = np.array(label_values)
    return colour_codes[image.astype(int)]
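A tiny round trip on a made-up 2x2 mask (hypothetical values, purely for illustration) shows how the three helpers fit together: one_hot_encode produces one channel per class, reverse_one_hot collapses that back to class indices, and colour_code_segmentation maps the indices back to the original colours.
toy_mask = np.array([[[0, 0, 0], [255, 255, 255]],
                     [[255, 255, 255], [0, 0, 0]]])           # (2, 2, 3) RGB mask
toy_one_hot = one_hot_encode(toy_mask, class_rgb_values)      # (2, 2, 2) one-hot map
toy_indices = reverse_one_hot(toy_one_hot)                    # [[0, 1], [1, 0]]
toy_restored = colour_code_segmentation(toy_indices, class_rgb_values)
print(toy_one_hot.shape, toy_indices, np.array_equal(toy_restored, toy_mask))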
class BuildingsDataset(torch.utils.data.Dataset):
    """Reads aerial tiles and their building masks, with optional augmentation and preprocessing."""

    def __init__(
        self,
        images_dir,
        masks_dir,
        class_rgb_values=None,
        augmentation=None,
        preprocessing=None,
    ):
        self.image_paths = [os.path.join(images_dir, image_id) for image_id in sorted(os.listdir(images_dir))]
        self.mask_paths = [os.path.join(masks_dir, image_id) for image_id in sorted(os.listdir(masks_dir))]
        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        # Read image and mask as RGB, then one-hot encode the mask.
        image = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.mask_paths[i]), cv2.COLOR_BGR2RGB)
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask

    def __len__(self):
        return len(self.image_paths)
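Before adding augmentation, a quick instantiation check (a sketch reusing the section 1 directories) shows what __getitem__ returns: the raw RGB tile as an H x W x 3 array and the mask as an H x W x 2 one-hot array.
dataset = BuildingsDataset(x_train_dir, y_train_dir, class_rgb_values=class_rgb_values)
image, mask = dataset[0]
print(image.shape, mask.shape)  # e.g. (1500, 1500, 3) and (1500, 1500, 2) for 1500px tiles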
def get_training_augmentation():
    train_transform = [
        album.RandomCrop(height=256, width=256, always_apply=True),
        album.OneOf(
            [
                album.HorizontalFlip(p=1),
                album.VerticalFlip(p=1),
                album.RandomRotate90(p=1),
            ],
            p=0.75,
        ),
    ]
    return album.Compose(train_transform)
def get_validation_augmentation():
    # Add sufficient padding to ensure image is divisible by 32
    test_transform = [
        album.PadIfNeeded(min_height=1536, min_width=1536, always_apply=True, border_mode=0),
    ]
    return album.Compose(test_transform)
def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

def get_preprocessing(preprocessing_fn=None):
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
    _transform.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(_transform)
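get_preprocessing chains the encoder's normalisation (when preprocessing_fn is given) with to_tensor, which only reorders the HWC array into the CHW, float32 layout PyTorch expects. A minimal sketch of the layout change, without any normalisation:
dummy = np.zeros((256, 256, 3), dtype=np.uint8)
sample = get_preprocessing()(image=dummy, mask=dummy.astype('float'))
print(sample['image'].shape, sample['image'].dtype)  # (3, 256, 256) float32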
3. Training
DATA_DIR = '/content/segmentation_models.pytorch/data/png'
SAVE_WEIGHT = '/content/segmentation_models.pytorch/weight'
os.makedirs(SAVE_WEIGHT, exist_ok=True)  # make sure the checkpoint directory exists before saving
ENCODER = 'resnet101'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = class_names
ACTIVATION = 'sigmoid'
TRAINING = True
EPOCHS = 80
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PRETRAINED_WEIGHT_PATH = None  # set to a .pth checkpoint path to resume from it
x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'val')
y_valid_dir = os.path.join(DATA_DIR, 'val_labels')
x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'test_labels')
class_names = ['background', 'building'] # select_classes
class_idx = [0,1] # select_class_indices
class_rgb_values = [[0,0,0], [255,255,255]] # select_class_rgb_values
#dataset = BuildingsDataset(x_train_dir, y_train_dir, class_rgb_values=class_rgb_values)
augmented_dataset = BuildingsDataset(
x_train_dir, y_train_dir,
augmentation=get_training_augmentation(),
class_rgb_values=class_rgb_values,
)
model = smp.DeepLabV3Plus(encoder_name=ENCODER,encoder_weights=ENCODER_WEIGHTS,classes=len(CLASSES),activation=ACTIVATION)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
if PRETRAINED_WEIGHT_PATH is not None and os.path.exists(PRETRAINED_WEIGHT_PATH):
    model = torch.load(PRETRAINED_WEIGHT_PATH, map_location=DEVICE)
train_dataset = BuildingsDataset(
x_train_dir, y_train_dir,
augmentation=get_training_augmentation(),
preprocessing=get_preprocessing(preprocessing_fn),
class_rgb_values=class_rgb_values,
)
valid_dataset = BuildingsDataset(
x_valid_dir, y_valid_dir,
augmentation=get_validation_augmentation(),
preprocessing=get_preprocessing(preprocessing_fn),
class_rgb_values= class_rgb_values,
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)
loss = su.losses.DiceLoss()
metrics = [su.metrics.IoU(threshold=0.5),]
optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=0.0001),])
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2, eta_min=5e-5,)
train_epoch = smp.utils.train.TrainEpoch(
model,
loss=loss,
metrics=metrics,
optimizer=optimizer,
device=DEVICE,
verbose=True,
)
valid_epoch = smp.utils.train.ValidEpoch(
model,
loss=loss,
metrics=metrics,
device=DEVICE,
verbose=True,
)
%%time
if TRAINING:
    best_iou_score = 0.0
    train_logs_list, valid_logs_list = [], []
    for i in range(0, EPOCHS):
        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        lr_scheduler.step()  # advance the cosine-annealing-with-restarts schedule once per epoch
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)
        # Save a checkpoint every epoch, plus a separate copy whenever validation IoU improves.
        file_name = f'model_epoch_{i:03d}.pth'
        torch.save(model, os.path.join(SAVE_WEIGHT, file_name))
        if best_iou_score < valid_logs['iou_score']:
            best_iou_score = valid_logs['iou_score']
            torch.save(model, os.path.join(SAVE_WEIGHT, 'best_model.pth'))
            print('Best Model Update!')
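pandas was installed above but not used yet; one natural follow-up is to turn the per-epoch logs into DataFrames and plot the IoU curves. This is only a sketch of one way to do it with the train_logs_list / valid_logs_list collected above.
train_logs_df = pd.DataFrame(train_logs_list)
valid_logs_df = pd.DataFrame(valid_logs_list)
plt.figure(figsize=(10, 5))
plt.plot(train_logs_df['iou_score'], label='train IoU')
plt.plot(valid_logs_df['iou_score'], label='valid IoU')
plt.xlabel('Epoch')
plt.ylabel('IoU score')
plt.legend()
plt.grid(True)
plt.show()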
4. Validation
import os, cv2
import numpy as np
import pandas as pd
import random
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils as su
DATA_DIR = '/content/segmentation_models.pytorch/data/png'
SAVE_WEIGHT = '/content/drive/MyDrive'
WEIGHT_FILE_NAME = '/content/drive/MyDrive/[practice7]building_segmentation.pth'
ENCODER = 'resnet101'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = class_names
ACTIVATION = 'sigmoid'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'test_labels')
class_names = ['background', 'building'] # select_classes
class_idx = [0,1] # select_class_indices
class_rgb_values = [[0,0,0], [255,255,255]] # select_class_rgb_values
model = smp.DeepLabV3Plus(encoder_name=ENCODER,encoder_weights=ENCODER_WEIGHTS,classes=len(CLASSES),activation=ACTIVATION)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
if os.path.exists(os.path.join(SAVE_WEIGHT, WEIGHT_FILE_NAME)):
    # Load the trained weights so the evaluation below runs on them, not on a freshly initialised model.
    model = torch.load(os.path.join(SAVE_WEIGHT, WEIGHT_FILE_NAME), map_location=DEVICE)
    print('Loaded DeepLabV3+ model from this run.')
test_dataset = BuildingsDataset(
x_test_dir,
y_test_dir,
augmentation=get_validation_augmentation(),
preprocessing=get_preprocessing(preprocessing_fn),
class_rgb_values=class_rgb_values,
)
test_dataloader = DataLoader(test_dataset)
test_epoch = su.train.ValidEpoch(
model,
loss=loss,
metrics=metrics,
device=DEVICE,
verbose=True,
)
valid_logs = test_epoch.run(test_dataloader)
print("Evaluation on Test Data: ")
print(f"Mean IoU Score: {valid_logs['iou_score']:.4f}")
print(f"Mean Dice Loss: {valid_logs['dice_loss']:.4f}")
5. Test
import os, cv2
import numpy as np
import pandas as pd
import random
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils as su
DATA_DIR = '/content/segmentation_models.pytorch/data/png'
SAVE_WEIGHT = '/content/drive/MyDrive'
WEIGHT_FILE_NAME = '/content/drive/MyDrive/[practice7]building_segmentation.pth'
ENCODER = 'resnet101'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = class_names
ACTIVATION = 'sigmoid'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'test_labels')
class_names = ['background', 'building'] # select_classes
class_idx = [0,1] # select_class_indices
class_rgb_values = [[0,0,0], [255,255,255]] # select_class_rgb_values
model = smp.DeepLabV3Plus(encoder_name=ENCODER,encoder_weights=ENCODER_WEIGHTS,classes=len(CLASSES),activation=ACTIVATION)
if os.path.exists(os.path.join(SAVE_WEIGHT, WEIGHT_FILE_NAME)):
    model = torch.load(os.path.join(SAVE_WEIGHT, WEIGHT_FILE_NAME), map_location=DEVICE)
    print('Loaded DeepLabV3+ model from this run.')
sample_preds_folder = '/content/segmentation_models.pytorch/Untitled Folder'
if not os.path.exists(sample_preds_folder):
    os.makedirs(sample_preds_folder)
# Center crop padded image / mask back to the original image dims
def crop_image(image, target_image_dims=[1500, 1500, 3]):
    target_size = target_image_dims[0]
    image_size = len(image)
    padding = (image_size - target_size) // 2
    return image[
        padding:image_size - padding,
        padding:image_size - padding,
        :,
    ]
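With the values used here the arithmetic is simple: the validation transform pads each side to 1536, the target is 1500, so padding = (1536 - 1500) // 2 = 18 and the slice keeps rows/columns 18..1517. A one-line check (a sketch, assuming the padding was symmetric):
print(crop_image(np.zeros((1536, 1536, 3))).shape)  # (1500, 1500, 3)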
test_dataset_vis = BuildingsDataset(
x_test_dir, y_test_dir,
augmentation=get_validation_augmentation(),
class_rgb_values=class_rgb_values,
)
test_dataset = BuildingsDataset(
x_test_dir,
y_test_dir,
augmentation=get_validation_augmentation(),
preprocessing=get_preprocessing(preprocessing_fn),
class_rgb_values=class_rgb_values,
)
model.eval()  # inference mode (affects BatchNorm / Dropout layers)
for idx in range(len(test_dataset)):
    image, gt_mask = test_dataset[idx]
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pred_mask = model(x_tensor)
    pred_mask = pred_mask.detach().squeeze().cpu().numpy()
    # Convert pred_mask from `CHW` format to `HWC` format
    pred_mask = np.transpose(pred_mask, (1, 2, 0))
    pred_building_heatmap = pred_mask[:, :, class_names.index('building')]
    pred_mask = crop_image(colour_code_segmentation(reverse_one_hot(pred_mask), class_rgb_values))
    # Convert gt_mask from `CHW` format to `HWC` format
    gt_mask = np.transpose(gt_mask, (1, 2, 0))
    gt_mask = crop_image(colour_code_segmentation(reverse_one_hot(gt_mask), class_rgb_values))
    image_vis = crop_image(test_dataset_vis[idx][0].astype('uint8'))
    # Stack image | ground truth | prediction, flipped RGB -> BGR for cv2.imwrite.
    result = np.hstack([image_vis, gt_mask, pred_mask])[:, :, ::-1]
    cv2.imwrite(os.path.join(sample_preds_folder, f"sample_pred_{idx}.png"), result)
    # Flip back to RGB for matplotlib display.
    plt.imshow(result[:, :, ::-1])
    plt.show()
6. Inference
import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import albumentations as album
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
DATA_DIR = '/content/segmentation_models.pytorch/data/png'
SAVE_WEIGHT = '/content/drive/MyDrive'
WEIGHT_FILE_NAME = '/content/drive/MyDrive/[practice7]building_segmentation.pth'
ENCODER = 'resnet101'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = class_names
ACTIVATION = 'sigmoid'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class_names = ['background', 'building'] # select_classes
class_idx = [0,1] # select_class_indices
class_rgb_values = [[0,0,0], [255,255,255]] # select_class_rgb_values
model = smp.DeepLabV3Plus(encoder_name=ENCODER,encoder_weights=ENCODER_WEIGHTS,classes=len(CLASSES),activation=ACTIVATION)
if os.path.exists(os.path.join(SAVE_WEIGHT, WEIGHT_FILE_NAME)):
    model = torch.load(os.path.join(SAVE_WEIGHT, WEIGHT_FILE_NAME), map_location=DEVICE)
    print('Loaded DeepLabV3+ model from this run.')
x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'test_labels')
def get_validation_augmentation():
    # Add sufficient padding to ensure image is divisible by 32
    test_transform = [
        album.PadIfNeeded(min_height=1536, min_width=1536, always_apply=True, border_mode=0),
    ]
    return album.Compose(test_transform)
test_img = cv2.imread(os.path.join(x_test_dir, os.listdir(x_test_dir)[0]))
test_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)
augmentor = album.PadIfNeeded(min_height=1536, min_width=1536, always_apply=True, border_mode=0)
input_image = augmentor(image=test_img)['image']
input_image = preprocessing_fn(input_image)
input_image = np.transpose(input_image,(2,0,1)).astype('float32')
input_image = torch.from_numpy(input_image).to(DEVICE).unsqueeze(0)
pred_mask = model(input_image)
pred_mask = pred_mask.detach().squeeze().cpu().numpy()
pred_mask = np.transpose(pred_mask,(1,2,0))
pred_building_heatmap = pred_mask[:,:,class_names.index('building')]
pred_mask = crop_image(colour_code_segmentation(reverse_one_hot(pred_mask), class_rgb_values)).astype('uint8')
test_img_label = cv2.imread(os.path.join(y_test_dir, os.listdir(y_test_dir)[0]))
test_img_label = crop_image(cv2.cvtColor(test_img_label, cv2.COLOR_BGR2RGB))
plt.figure(figsize=(12,4))
plt.subplot(1,3,1)
plt.title("Image")
plt.imshow(test_img)
plt.subplot(1,3,2)
plt.title("GT MASK")
plt.imshow(test_img_label)
plt.subplot(1,3,3)
plt.title("PRED MASK")
plt.imshow(pred_mask)
plt.show()
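pred_building_heatmap is computed above but never shown; to inspect the raw sigmoid scores for the building channel, a small optional sketch (cropped the same way as the mask):
plt.figure(figsize=(6, 6))
plt.title('Building probability heatmap')
plt.imshow(crop_image(pred_building_heatmap[:, :, np.newaxis]).squeeze(), cmap='jet')
plt.colorbar()
plt.show()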