# segmentation module

## `CustomDataset` (`Dataset`)

Custom Dataset for loading images and masks.

Source code in `geoai/segmentation.py`
class CustomDataset(Dataset):
    """Custom Dataset for loading images and masks.

    Images and masks are paired by position after sorting the file names of
    each directory, so the i-th image (alphabetically) is matched with the
    i-th mask.
    """

    def __init__(
        self,
        images_dir: str,
        masks_dir: str,
        transform: A.Compose = None,
        target_size: tuple = (256, 256),
        num_classes: int = 2,
    ):
        """
        Args:
            images_dir (str): Directory containing images.
            masks_dir (str): Directory containing masks.
            transform (A.Compose, optional): Transformations to be applied on the
                images and masks, called as ``transform(image=..., mask=...)``.
            target_size (tuple, optional): Target size for resizing images and
                masks, passed to ``PIL.Image.resize`` as ``(width, height)``.
            num_classes (int, optional): Number of classes in the masks.

        Raises:
            ValueError: If the two directories hold a different number of
                entries; positional pairing would otherwise misalign samples.
        """
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.target_size = target_size
        self.num_classes = num_classes
        self.images = sorted(os.listdir(images_dir))
        self.masks = sorted(os.listdir(masks_dir))
        # Samples are paired by sorted position, so the counts must match.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_dir!r} but "
                f"{len(self.masks)} masks in {masks_dir!r}; they must pair one-to-one."
            )

    def __len__(self) -> int:
        """Returns the total number of samples."""
        return len(self.images)

    def __getitem__(self, idx: int) -> dict:
        """
        Args:
            idx (int): Index of the sample to fetch.

        Returns:
            dict: A dictionary with 'pixel_values' and 'labels'.
            'pixel_values' is whatever ``transform`` returns for the image. With
            no transform, it is the ``(H, W, 3)`` uint8 array of the resized RGB
            image. 'labels' is the binarized mask as a ``torch.long`` tensor.

        Raises:
            ValueError: If a mask value falls outside ``[0, num_classes)``.
        """
        img_path = os.path.join(self.images_dir, self.images[idx])
        mask_path = os.path.join(self.masks_dir, self.masks[idx])

        # Context managers close the underlying files right away instead of
        # leaving them open until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB").resize(self.target_size)
        with Image.open(mask_path) as msk:
            mask = msk.convert("L").resize(self.target_size)

        image = np.array(image)
        # Binarize: grayscale values above 127 become class 1, the rest class 0.
        mask = (np.array(mask) > 127).astype(np.uint8)

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        # Without a transform (or one lacking a tensor conversion) the mask is
        # still a NumPy array, which has no ``.clone()``.
        # ``torch.as_tensor`` accepts both arrays and tensors.
        labels = torch.as_tensor(mask).long()

        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        max_label = int(labels.max())
        min_label = int(labels.min())
        if max_label >= self.num_classes:
            raise ValueError(
                f"Mask values should be less than {self.num_classes}, but found {max_label}"
            )
        if min_label < 0:
            raise ValueError(
                f"Mask values should be greater than or equal to 0, but found {min_label}"
            )
        return {"pixel_values": image, "labels": labels}
### `__getitem__(self, idx)` *(special method)*

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `idx` | `int` | Index of the sample to fetch. | *required* |

**Returns:**

| Type | Description |
|---|---|
| `dict` | A dictionary with `'pixel_values'` and `'labels'`. |

Source code in `geoai/segmentation.py`
def __getitem__(self, idx: int) -> dict:
    """
    Args:
        idx (int): Index of the sample to fetch.

    Returns:
        dict: A dictionary with 'pixel_values' and 'labels'.
        'pixel_values' is whatever ``transform`` returns for the image. With no
        transform, it is the ``(H, W, 3)`` uint8 array of the resized RGB image.
        'labels' is the binarized mask as a ``torch.long`` tensor.

    Raises:
        ValueError: If a mask value falls outside ``[0, num_classes)``.
    """
    img_path = os.path.join(self.images_dir, self.images[idx])
    mask_path = os.path.join(self.masks_dir, self.masks[idx])

    # Context managers close the underlying files right away instead of
    # leaving them open until garbage collection.
    with Image.open(img_path) as img:
        image = img.convert("RGB").resize(self.target_size)
    with Image.open(mask_path) as msk:
        mask = msk.convert("L").resize(self.target_size)

    image = np.array(image)
    # Binarize: grayscale values above 127 become class 1, the rest class 0.
    mask = (np.array(mask) > 127).astype(np.uint8)

    if self.transform is not None:
        transformed = self.transform(image=image, mask=mask)
        image = transformed["image"]
        mask = transformed["mask"]

    # Without a transform (or one lacking a tensor conversion) the mask is
    # still a NumPy array, which has no ``.clone()``.
    # ``torch.as_tensor`` accepts both arrays and tensors.
    labels = torch.as_tensor(mask).long()

    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    max_label = int(labels.max())
    min_label = int(labels.min())
    if max_label >= self.num_classes:
        raise ValueError(
            f"Mask values should be less than {self.num_classes}, but found {max_label}"
        )
    if min_label < 0:
        raise ValueError(
            f"Mask values should be greater than or equal to 0, but found {min_label}"
        )
    return {"pixel_values": image, "labels": labels}
### `__init__(self, images_dir, masks_dir, transform=None, target_size=(256, 256), num_classes=2)` *(special method)*

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `images_dir` | `str` | Directory containing images. | *required* |
| `masks_dir` | `str` | Directory containing masks. | *required* |
| `transform` | `A.Compose` | Transformations to be applied on the images and masks. | `None` |
| `target_size` | `tuple` | Target size for resizing images and masks. | `(256, 256)` |
| `num_classes` | `int` | Number of classes in the masks. | `2` |

Source code in `geoai/segmentation.py`
def __init__(
    self,
    images_dir: str,
    masks_dir: str,
    transform: A.Compose = None,
    target_size: tuple = (256, 256),
    num_classes: int = 2,
):
    """
    Args:
        images_dir (str): Directory containing images.
        masks_dir (str): Directory containing masks.
        transform (A.Compose, optional): Transformations to be applied on the
            images and masks.
        target_size (tuple, optional): Target size for resizing images and masks.
        num_classes (int, optional): Number of classes in the masks.

    Raises:
        ValueError: If the two directories hold a different number of entries.
            Images and masks are paired by sorted position, so a count mismatch
            would silently misalign samples.
    """
    self.images_dir = images_dir
    self.masks_dir = masks_dir
    self.transform = transform
    self.target_size = target_size
    self.num_classes = num_classes
    self.images = sorted(os.listdir(images_dir))
    self.masks = sorted(os.listdir(masks_dir))
    if len(self.images) != len(self.masks):
        raise ValueError(
            f"Found {len(self.images)} images in {images_dir!r} but "
            f"{len(self.masks)} masks in {masks_dir!r}; they must pair one-to-one."
        )
### `__len__(self)` *(special method)*

Returns the total number of samples.

Source code in `geoai/segmentation.py`
def __len__(self) -> int:
    """Report the sample count, i.e. how many image files were found."""
    image_names = self.images
    return len(image_names)
## `get_transform()`

**Returns:**

| Type | Description |
|---|---|
| `A.Compose` | A composition of image transformations. |

Source code in `geoai/segmentation.py`
def get_transform() -> A.Compose:
    """
    Build the training-time augmentation pipeline.

    Returns:
        A.Compose: A pipeline with these steps:
        - Resize to 256x256.
        - Random horizontal flip, vertical flip and 90-degree rotation, each
          with probability 0.5.
        - Normalize with the standard ImageNet channel mean/std.
        - Convert to a PyTorch tensor.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    pipeline_steps = [A.Resize(256, 256)]
    # Geometric augmentations, each applied independently at random.
    pipeline_steps += [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
    pipeline_steps.append(A.Normalize(mean=channel_mean, std=channel_std))
    pipeline_steps.append(ToTensorV2())
    return A.Compose(pipeline_steps)
## `load_model(model_path, device)`

Loads the fine-tuned model from the specified path.

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `model_path` | `str` | Path to the model. | *required* |
| `device` | `torch.device` | Device to load the model on. | *required* |

**Returns:**

| Type | Description |
|---|---|
| `SegformerForSemanticSegmentation` | Loaded model. |

Source code in `geoai/segmentation.py`
def load_model(
    model_path: str, device: torch.device
) -> SegformerForSemanticSegmentation:
    """
    Load a fine-tuned SegFormer checkpoint and ready it for inference.

    Args:
        model_path (str): Path to the model, passed to ``from_pretrained``.
        device (torch.device): Device the model weights are moved to.

    Returns:
        SegformerForSemanticSegmentation: The loaded model, placed on
        ``device`` and switched to evaluation mode.
    """
    segformer = SegformerForSemanticSegmentation.from_pretrained(model_path)
    segformer.to(device)
    # Evaluation mode turns off training-only behavior such as dropout.
    segformer.eval()
    return segformer
## `predict_image(model, image_tensor, original_size, device)`

Predicts the segmentation mask for the input image.

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `SegformerForSemanticSegmentation` | Fine-tuned model. | *required* |
| `image_tensor` | `torch.Tensor` | Preprocessed image tensor. | *required* |
| `original_size` | `tuple` | Original size of the image `(width, height)`. | *required* |
| `device` | `torch.device` | Device to perform inference on. | *required* |

**Returns:**

| Type | Description |
|---|---|
| `np.ndarray` | Predicted segmentation mask. |

Source code in `geoai/segmentation.py`
def predict_image(
    model: SegformerForSemanticSegmentation,
    image_tensor: torch.Tensor,
    original_size: tuple,
    device: torch.device,
) -> np.ndarray:
    """
    Run the model on one preprocessed image and return its class map.

    Args:
        model (SegformerForSemanticSegmentation): Fine-tuned model.
        image_tensor (torch.Tensor): Preprocessed image tensor with a leading
            batch dimension; only the first batch item is returned.
        original_size (tuple): Original size of the image (width, height).
        device (torch.device): Device to perform inference on.

    Returns:
        np.ndarray: Per-pixel argmax class ids at the original resolution,
        shaped (height, width).
    """
    # PIL reports (width, height); F.interpolate expects (height, width).
    target_hw = original_size[::-1]
    with torch.no_grad():
        batch = image_tensor.to(device)
        raw_logits = model(pixel_values=batch).logits
        # Upsample to the source resolution before picking the best class.
        resized_logits = F.interpolate(
            raw_logits, size=target_hw, mode="bilinear", align_corners=False
        )
        class_ids = resized_logits.argmax(dim=1)
    return class_ids.cpu().numpy()[0]
## `prepare_datasets(images_dir, masks_dir, transform, test_size=0.2, random_state=42)`

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `images_dir` | `str` | Directory containing images. | *required* |
| `masks_dir` | `str` | Directory containing masks. | *required* |
| `transform` | `A.Compose` | Transformations to be applied. | *required* |
| `test_size` | `float` | Proportion of the dataset to include in the validation split. | `0.2` |
| `random_state` | `int` | Random seed for shuffling the dataset. | `42` |

**Returns:**

| Type | Description |
|---|---|
| `tuple` | Training and validation datasets. |

Source code in `geoai/segmentation.py`
def prepare_datasets(
    images_dir: str,
    masks_dir: str,
    transform: A.Compose,
    test_size: float = 0.2,
    random_state: int = 42,
    *,
    val_transform: A.Compose = None,
    target_size: tuple = (256, 256),
    num_classes: int = 2,
) -> tuple:
    """
    Split an image/mask directory pair into training and validation subsets.

    Args:
        images_dir (str): Directory containing images.
        masks_dir (str): Directory containing masks.
        transform (A.Compose): Transformations applied to training samples.
        test_size (float, optional): Proportion of the dataset to include in the
            validation split.
        random_state (int, optional): Random seed for shuffling the dataset.
        val_transform (A.Compose, optional): Transformations for validation
            samples. When None (the default), the validation subset shares the
            training dataset and therefore ``transform``, including any random
            augmentations it contains. Pass a deterministic pipeline to evaluate
            on un-augmented samples.
        target_size (tuple, optional): Forwarded to ``CustomDataset``.
        num_classes (int, optional): Forwarded to ``CustomDataset``.

    Returns:
        tuple: ``(train_dataset, val_dataset)`` as ``Subset`` objects.
    """
    train_source = CustomDataset(
        images_dir,
        masks_dir,
        transform,
        target_size=target_size,
        num_classes=num_classes,
    )
    train_indices, val_indices = train_test_split(
        list(range(len(train_source))), test_size=test_size, random_state=random_state
    )
    if val_transform is None:
        # Original behavior: both subsets index the same dataset instance.
        val_source = train_source
    else:
        # Same directories, sorted the same way, so the split indices line up.
        val_source = CustomDataset(
            images_dir,
            masks_dir,
            val_transform,
            target_size=target_size,
            num_classes=num_classes,
        )
    train_dataset = Subset(train_source, train_indices)
    val_dataset = Subset(val_source, val_indices)
    return train_dataset, val_dataset
## `preprocess_image(image_path, target_size=(256, 256))`

Preprocesses the input image for prediction.

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `image_path` | `str` | Path to the input image. | *required* |
| `target_size` | `tuple` | Target size for resizing the image, interpreted as `(height, width)`. | `(256, 256)` |

**Returns:**

| Type | Description |
|---|---|
| `torch.Tensor` | Preprocessed image tensor. |

Source code in `geoai/segmentation.py`
def preprocess_image(image_path: str, target_size: tuple = (256, 256)) -> torch.Tensor:
    """
    Preprocesses the input image for prediction.

    The image is converted to RGB, resized, and normalized with the same
    mean/std used by ``get_transform``. It is then turned into a tensor with a
    leading batch dimension.

    Args:
        image_path (str): Path to the input image.
        target_size (tuple, optional): Target size for resizing the image.
            ``A.Resize`` takes ``(height, width)``, so ``target_size`` is read as
            ``(height, width)`` here. This differs from the PIL-based resizing
            elsewhere in this module when the size is not square.

    Returns:
        torch.Tensor: Preprocessed image tensor with a batch dimension of 1.
    """
    # Decode inside a context manager so the file handle is released promptly.
    with Image.open(image_path) as img:
        rgb_array = np.array(img.convert("RGB"))
    transform = A.Compose(
        [
            A.Resize(target_size[0], target_size[1]),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )
    transformed = transform(image=rgb_array)
    return transformed["image"].unsqueeze(0)
## `segment_image(image_path, model_path, target_size=(256, 256), device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))`

Segments the input image using the fine-tuned model.

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `image_path` | `str` | Path to the input image. | *required* |
| `model_path` | `str` | Path to the fine-tuned model. | *required* |
| `target_size` | `tuple` | Target size for resizing the image. | `(256, 256)` |
| `device` | `torch.device` | Device to perform inference on. | CUDA if available, otherwise CPU (chosen at import time; shown as `device(type='cpu')` on the machine that built these docs) |

**Returns:**

| Type | Description |
|---|---|
| `np.ndarray` | Predicted segmentation mask. |

Source code in `geoai/segmentation.py`
def segment_image(
    image_path: str,
    model_path: str,
    target_size: tuple = (256, 256),
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> np.ndarray:
    """
    Segments the input image using the fine-tuned model.

    The model is loaded from ``model_path`` on every call.

    Args:
        image_path (str): Path to the input image.
        model_path (str): Path to the fine-tuned model.
        target_size (tuple, optional): Target size for resizing the image
            before inference (see ``preprocess_image``).
        device (torch.device, optional): Device to perform inference on. The
            default is evaluated once at import time: CUDA if available,
            otherwise CPU.

    Returns:
        np.ndarray: Predicted segmentation mask at the image's original
        resolution.
    """
    model = load_model(model_path, device)
    # Only the (width, height) from the header is needed here. Reading it
    # without decoding/converting avoids a second full decode
    # (preprocess_image decodes the file itself), and the context manager
    # closes the file.
    with Image.open(image_path) as img:
        original_size = img.size
    image_tensor = preprocess_image(image_path, target_size)
    return predict_image(model, image_tensor, original_size, device)
## `train_model(train_dataset, val_dataset, pretrained_model='nvidia/segformer-b0-finetuned-ade-512-512', model_save_path='./model', output_dir='./results', num_epochs=10, batch_size=8, learning_rate=5e-05)`

Trains the model and saves the fine-tuned model to the specified path.

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `train_dataset` | `Dataset` | Training dataset. | *required* |
| `val_dataset` | `Dataset` | Validation dataset. | *required* |
| `pretrained_model` | `str` | Pretrained model to fine-tune. | `'nvidia/segformer-b0-finetuned-ade-512-512'` |
| `model_save_path` | `str` | Path to save the fine-tuned model. | `'./model'` |
| `output_dir` | `str` | Directory to save training outputs. | `'./results'` |
| `num_epochs` | `int` | Number of training epochs. | `10` |
| `batch_size` | `int` | Batch size for training and evaluation. | `8` |
| `learning_rate` | `float` | Learning rate for training. | `5e-05` |

**Returns:**

| Type | Description |
|---|---|
| `str` | Path to the saved fine-tuned model. |

Source code in `geoai/segmentation.py`
def train_model(
    train_dataset: Dataset,
    val_dataset: Dataset,
    pretrained_model: str = "nvidia/segformer-b0-finetuned-ade-512-512",
    model_save_path: str = "./model",
    output_dir: str = "./results",
    num_epochs: int = 10,
    batch_size: int = 8,
    learning_rate: float = 5e-5,
    *,
    logging_dir: str = "./logs",
    num_labels: int = None,
) -> str:
    """
    Trains the model and saves the fine-tuned model to the specified path.

    Args:
        train_dataset (Dataset): Training dataset.
        val_dataset (Dataset): Validation dataset.
        pretrained_model (str, optional): Pretrained model to fine-tune.
        model_save_path (str, optional): Path to save the fine-tuned model.
        output_dir (str, optional): Directory to save training outputs
            (per-epoch checkpoints are written here by the Trainer).
        num_epochs (int, optional): Number of training epochs.
        batch_size (int, optional): Batch size for training and evaluation.
        learning_rate (float, optional): Learning rate for training.
        logging_dir (str, optional): Directory for training logs. Previously
            hard-coded to "./logs", which remains the default.
        num_labels (int, optional): If given, the model's classification head
            is re-initialized with this many outputs, e.g. 2 for binary masks.
            If None (default), the pretrained checkpoint's head is kept as-is.

    Returns:
        str: Path to the saved fine-tuned model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if num_labels is None:
        model = SegformerForSemanticSegmentation.from_pretrained(pretrained_model)
    else:
        # The checkpoint's head may be sized for a different label set;
        # ignore_mismatched_sizes re-initializes it instead of raising on the
        # weight-shape mismatch.
        model = SegformerForSemanticSegmentation.from_pretrained(
            pretrained_model, num_labels=num_labels, ignore_mismatched_sizes=True
        )
    model = model.to(device)

    data_collator = DefaultDataCollator(return_tensors="pt")
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_dir=logging_dir,
        learning_rate=learning_rate,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )
    trainer.train()
    # Saves the model as it stands after the final epoch.
    model.save_pretrained(model_save_path)
    print(f"Model saved to {model_save_path}")
    return model_save_path
## `visualize_predictions(image_path, segmented_mask, target_size=(256, 256), reference_image_path=None)`

Visualizes the original image, segmented mask, and optionally the reference image.

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `image_path` | `str` | Path to the original image. | *required* |
| `segmented_mask` | `np.ndarray` | Predicted segmentation mask. | *required* |
| `target_size` | `tuple` | Target size for resizing images. | `(256, 256)` |
| `reference_image_path` | `str` | Path to the reference image. | `None` |

Source code in `geoai/segmentation.py`
def _mask_to_gray(segmented_mask: np.ndarray) -> np.ndarray:
    """Map integer class ids to distinct uint8 gray levels (0 and 255 for binary masks)."""
    mask = np.asarray(segmented_mask)
    peak = float(mask.max()) if mask.size else 0.0
    # A fixed ``* 255`` wraps around in uint8 for ids > 1 (2 -> 254, 3 -> 253).
    # Scale so the largest id maps to 255 instead; binary masks are unchanged.
    scale = 255.0 / peak if peak > 1 else 255.0
    return np.clip(mask * scale, 0, 255).astype(np.uint8)


def _load_resized_rgb(path: str, size: tuple):
    """Open an image, convert it to RGB, resize it, and close the file."""
    with Image.open(path) as img:
        return img.convert("RGB").resize(size)


def visualize_predictions(
    image_path: str,
    segmented_mask: np.ndarray,
    target_size: tuple = (256, 256),
    reference_image_path: str = None,
) -> None:
    """
    Visualizes the original image, segmented mask, and optionally the reference image.

    Panels are laid out left to right: the original image, then the reference
    image (only when ``reference_image_path`` is given), then the segmented
    mask in grayscale.

    Args:
        image_path (str): Path to the original image.
        segmented_mask (np.ndarray): Predicted segmentation mask of class ids.
        target_size (tuple, optional): Target size for resizing the original
            and reference images, passed to PIL as ``(width, height)``. The
            mask is displayed at its own resolution.
        reference_image_path (str, optional): Path to the reference image.
    """
    panels = [(_load_resized_rgb(image_path, target_size), "Original Image", None)]
    if reference_image_path:
        panels.append(
            (_load_resized_rgb(reference_image_path, target_size), "Reference Image", None)
        )
    panels.append((_mask_to_gray(segmented_mask), "Segmented Image", "gray"))

    # 6 inches of width per panel: (12, 6) for two panels, (18, 6) for three.
    _, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 6))
    for ax, (picture, title, cmap) in zip(axes, panels):
        ax.imshow(picture, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()
    plt.show()