segmentation module

CustomDataset

Bases: Dataset

Custom Dataset for loading images and masks.

Source code in geoai/segmentation.py
```python
class CustomDataset(Dataset):
    """Custom Dataset for loading images and masks."""

    def __init__(
        self,
        images_dir: str,
        masks_dir: str,
        transform: A.Compose = None,
        target_size: tuple = (256, 256),
        num_classes: int = 2,
    ):
        """
        Args:
            images_dir (str): Directory containing images.
            masks_dir (str): Directory containing masks.
            transform (A.Compose, optional): Transformations to be applied on the images and masks.
            target_size (tuple, optional): Target size for resizing images and masks.
            num_classes (int, optional): Number of classes in the masks.
        """
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.target_size = target_size
        self.num_classes = num_classes
        self.images = sorted(os.listdir(images_dir))
        self.masks = sorted(os.listdir(masks_dir))

    def __len__(self) -> int:
        """Returns the total number of samples."""
        return len(self.images)

    def __getitem__(self, idx: int) -> dict:
        """
        Args:
            idx (int): Index of the sample to fetch.

        Returns:
            dict: A dictionary with 'pixel_values' and 'labels'.
        """
        img_path = os.path.join(self.images_dir, self.images[idx])
        mask_path = os.path.join(self.masks_dir, self.masks[idx])
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        image = image.resize(self.target_size)
        mask = mask.resize(self.target_size)

        image = np.array(image)
        mask = np.array(mask)

        mask = (mask > 127).astype(np.uint8)

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        assert (
            mask.max() < self.num_classes
        ), f"Mask values should be less than {self.num_classes}, but found {mask.max()}"
        assert (
            mask.min() >= 0
        ), f"Mask values should be greater than or equal to 0, but found {mask.min()}"

        mask = mask.clone().detach().long()

        return {"pixel_values": image, "labels": mask}
```
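
A minimal usage sketch (the directory paths are hypothetical placeholders). Note that `__getitem__` ends with `mask.clone()`, so in practice the transform should convert the mask to a torch tensor, e.g. by ending with `ToTensorV2()` as `get_transform()` below does:

```python
# Hypothetical paths; point these at your own image/mask folders.
dataset = CustomDataset(
    images_dir="data/images",
    masks_dir="data/masks",
    transform=get_transform(),  # see get_transform() below
    target_size=(256, 256),
    num_classes=2,
)
sample = dataset[0]
print(sample["pixel_values"].shape)  # torch.Size([3, 256, 256])
print(sample["labels"].shape)        # torch.Size([256, 256])
```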

__getitem__(idx)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| idx | int | Index of the sample to fetch. | required |

Returns:

| Type | Description |
| --- | --- |
| dict | A dictionary with 'pixel_values' and 'labels'. |

Source code in geoai/segmentation.py
```python
def __getitem__(self, idx: int) -> dict:
    """
    Args:
        idx (int): Index of the sample to fetch.

    Returns:
        dict: A dictionary with 'pixel_values' and 'labels'.
    """
    img_path = os.path.join(self.images_dir, self.images[idx])
    mask_path = os.path.join(self.masks_dir, self.masks[idx])
    image = Image.open(img_path).convert("RGB")
    mask = Image.open(mask_path).convert("L")

    image = image.resize(self.target_size)
    mask = mask.resize(self.target_size)

    image = np.array(image)
    mask = np.array(mask)

    mask = (mask > 127).astype(np.uint8)

    if self.transform:
        transformed = self.transform(image=image, mask=mask)
        image = transformed["image"]
        mask = transformed["mask"]

    assert (
        mask.max() < self.num_classes
    ), f"Mask values should be less than {self.num_classes}, but found {mask.max()}"
    assert (
        mask.min() >= 0
    ), f"Mask values should be greater than or equal to 0, but found {mask.min()}"

    mask = mask.clone().detach().long()

    return {"pixel_values": image, "labels": mask}
```

__init__(images_dir, masks_dir, transform=None, target_size=(256, 256), num_classes=2)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| images_dir | str | Directory containing images. | required |
| masks_dir | str | Directory containing masks. | required |
| transform | A.Compose | Transformations to be applied on the images and masks. | None |
| target_size | tuple | Target size for resizing images and masks. | (256, 256) |
| num_classes | int | Number of classes in the masks. | 2 |
Source code in geoai/segmentation.py
```python
def __init__(
    self,
    images_dir: str,
    masks_dir: str,
    transform: A.Compose = None,
    target_size: tuple = (256, 256),
    num_classes: int = 2,
):
    """
    Args:
        images_dir (str): Directory containing images.
        masks_dir (str): Directory containing masks.
        transform (A.Compose, optional): Transformations to be applied on the images and masks.
        target_size (tuple, optional): Target size for resizing images and masks.
        num_classes (int, optional): Number of classes in the masks.
    """
    self.images_dir = images_dir
    self.masks_dir = masks_dir
    self.transform = transform
    self.target_size = target_size
    self.num_classes = num_classes
    self.images = sorted(os.listdir(images_dir))
    self.masks = sorted(os.listdir(masks_dir))
```

__len__()

Returns the total number of samples.

Source code in geoai/segmentation.py
```python
def __len__(self) -> int:
    """Returns the total number of samples."""
    return len(self.images)
```

get_transform()

Returns:

| Type | Description |
| --- | --- |
| A.Compose | A composition of image transformations. |

Source code in geoai/segmentation.py
```python
def get_transform() -> A.Compose:
    """
    Returns:
        A.Compose: A composition of image transformations.
    """
    return A.Compose(
        [
            A.Resize(256, 256),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )
```
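
A quick sketch of what the composition produces, using random arrays in place of a real image and mask:

```python
import numpy as np

transform = get_transform()

# Dummy RGB image and binary mask standing in for real data.
image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (512, 512), dtype=np.uint8)

out = transform(image=image, mask=mask)
print(out["image"].shape)  # torch.Size([3, 256, 256]); normalized, channels-first
print(out["mask"].shape)   # torch.Size([256, 256]); resized but not normalized
```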

load_model(model_path, device)

Loads the fine-tuned model from the specified path.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_path | str | Path to the model. | required |
| device | torch.device | Device to load the model on. | required |

Returns:

| Type | Description |
| --- | --- |
| SegformerForSemanticSegmentation | Loaded model. |

Source code in geoai/segmentation.py
```python
def load_model(
    model_path: str, device: torch.device
) -> SegformerForSemanticSegmentation:
    """
    Loads the fine-tuned model from the specified path.

    Args:
        model_path (str): Path to the model.
        device (torch.device): Device to load the model on.

    Returns:
        SegformerForSemanticSegmentation: Loaded model.
    """
    model = SegformerForSemanticSegmentation.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return model
```
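
For example, to load a model saved by `train_model()` below (`./model` is its default save path):

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("./model", device)
```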

predict_image(model, image_tensor, original_size, device)

Predicts the segmentation mask for the input image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | SegformerForSemanticSegmentation | Fine-tuned model. | required |
| image_tensor | torch.Tensor | Preprocessed image tensor. | required |
| original_size | tuple | Original size of the image (width, height). | required |
| device | torch.device | Device to perform inference on. | required |

Returns:

| Type | Description |
| --- | --- |
| np.ndarray | Predicted segmentation mask. |

Source code in geoai/segmentation.py
```python
def predict_image(
    model: SegformerForSemanticSegmentation,
    image_tensor: torch.Tensor,
    original_size: tuple,
    device: torch.device,
) -> np.ndarray:
    """
    Predicts the segmentation mask for the input image.

    Args:
        model (SegformerForSemanticSegmentation): Fine-tuned model.
        image_tensor (torch.Tensor): Preprocessed image tensor.
        original_size (tuple): Original size of the image (width, height).
        device (torch.device): Device to perform inference on.

    Returns:
        np.ndarray: Predicted segmentation mask.
    """
    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        outputs = model(pixel_values=image_tensor)
        logits = outputs.logits
        upsampled_logits = F.interpolate(
            logits, size=original_size[::-1], mode="bilinear", align_corners=False
        )
        predictions = torch.argmax(upsampled_logits, dim=1).cpu().numpy()
    return predictions[0]
```
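
A sketch of the expected call pattern, assuming a hypothetical input file `example.jpg` and a model loaded with `load_model()`. Note that `original_size` follows PIL's `(width, height)` convention, which the function reverses internally for `F.interpolate`:

```python
image = Image.open("example.jpg").convert("RGB")
image_tensor = preprocess_image("example.jpg")  # shape [1, 3, 256, 256]
mask = predict_image(model, image_tensor, image.size, device)
print(mask.shape)  # (height, width) of the original image
```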

prepare_datasets(images_dir, masks_dir, transform, test_size=0.2, random_state=42)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| images_dir | str | Directory containing images. | required |
| masks_dir | str | Directory containing masks. | required |
| transform | A.Compose | Transformations to be applied. | required |
| test_size | float | Proportion of the dataset to include in the validation split. | 0.2 |
| random_state | int | Random seed for shuffling the dataset. | 42 |

Returns:

| Type | Description |
| --- | --- |
| tuple | Training and validation datasets. |

Source code in geoai/segmentation.py
```python
def prepare_datasets(
    images_dir: str,
    masks_dir: str,
    transform: A.Compose,
    test_size: float = 0.2,
    random_state: int = 42,
) -> tuple:
    """
    Args:
        images_dir (str): Directory containing images.
        masks_dir (str): Directory containing masks.
        transform (A.Compose): Transformations to be applied.
        test_size (float, optional): Proportion of the dataset to include in the validation split.
        random_state (int, optional): Random seed for shuffling the dataset.

    Returns:
        tuple: Training and validation datasets.
    """
    dataset = CustomDataset(images_dir, masks_dir, transform)
    train_indices, val_indices = train_test_split(
        list(range(len(dataset))), test_size=test_size, random_state=random_state
    )
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)
    return train_dataset, val_dataset
```
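
A minimal sketch, again with hypothetical directory paths:

```python
transform = get_transform()
train_dataset, val_dataset = prepare_datasets(
    "data/images", "data/masks", transform, test_size=0.2, random_state=42
)
print(len(train_dataset), len(val_dataset))  # an 80/20 split of the samples
```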

preprocess_image(image_path, target_size=(256, 256))

Preprocesses the input image for prediction.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image_path | str | Path to the input image. | required |
| target_size | tuple | Target size for resizing the image. | (256, 256) |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | Preprocessed image tensor. |

Source code in geoai/segmentation.py
```python
def preprocess_image(image_path: str, target_size: tuple = (256, 256)) -> torch.Tensor:
    """
    Preprocesses the input image for prediction.

    Args:
        image_path (str): Path to the input image.
        target_size (tuple, optional): Target size for resizing the image.

    Returns:
        torch.Tensor: Preprocessed image tensor.
    """
    image = Image.open(image_path).convert("RGB")
    transform = A.Compose(
        [
            A.Resize(target_size[0], target_size[1]),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )
    image = np.array(image)
    transformed = transform(image=image)
    return transformed["image"].unsqueeze(0)
```
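
For example (hypothetical file path):

```python
image_tensor = preprocess_image("example.jpg", target_size=(256, 256))
print(image_tensor.shape)  # torch.Size([1, 3, 256, 256])
```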

segment_image(image_path, model_path, target_size=(256, 256), device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

Segments the input image using the fine-tuned model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image_path | str | Path to the input image. | required |
| model_path | str | Path to the fine-tuned model. | required |
| target_size | tuple | Target size for resizing the image. | (256, 256) |
| device | torch.device | Device to perform inference on. | torch.device('cuda' if torch.cuda.is_available() else 'cpu') |

Returns:

| Type | Description |
| --- | --- |
| np.ndarray | Predicted segmentation mask. |

Source code in geoai/segmentation.py
```python
def segment_image(
    image_path: str,
    model_path: str,
    target_size: tuple = (256, 256),
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> np.ndarray:
    """
    Segments the input image using the fine-tuned model.

    Args:
        image_path (str): Path to the input image.
        model_path (str): Path to the fine-tuned model.
        target_size (tuple, optional): Target size for resizing the image.
        device (torch.device, optional): Device to perform inference on.

    Returns:
        np.ndarray: Predicted segmentation mask.
    """
    model = load_model(model_path, device)
    image = Image.open(image_path).convert("RGB")
    original_size = image.size
    image_tensor = preprocess_image(image_path, target_size)
    predictions = predict_image(model, image_tensor, original_size, device)
    return predictions
```
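
This is the simplest end-to-end entry point: it loads the model, preprocesses the image, and returns a mask at the original image size. A sketch with hypothetical paths:

```python
mask = segment_image("example.jpg", "./model")
print(mask.shape)  # (height, width) of the input image, one class index per pixel
```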

train_model(train_dataset, val_dataset, pretrained_model='nvidia/segformer-b0-finetuned-ade-512-512', model_save_path='./model', output_dir='./results', num_epochs=10, batch_size=8, learning_rate=5e-05)

Trains the model and saves the fine-tuned model to the specified path.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| train_dataset | Dataset | Training dataset. | required |
| val_dataset | Dataset | Validation dataset. | required |
| pretrained_model | str | Pretrained model to fine-tune. | 'nvidia/segformer-b0-finetuned-ade-512-512' |
| model_save_path | str | Path to save the fine-tuned model. | './model' |
| output_dir | str | Directory to save training outputs. | './results' |
| num_epochs | int | Number of training epochs. | 10 |
| batch_size | int | Batch size for training and evaluation. | 8 |
| learning_rate | float | Learning rate for training. | 5e-05 |

Returns:

| Type | Description |
| --- | --- |
| str | Path to the saved fine-tuned model. |

Source code in geoai/segmentation.py
```python
def train_model(
    train_dataset: Dataset,
    val_dataset: Dataset,
    pretrained_model: str = "nvidia/segformer-b0-finetuned-ade-512-512",
    model_save_path: str = "./model",
    output_dir: str = "./results",
    num_epochs: int = 10,
    batch_size: int = 8,
    learning_rate: float = 5e-5,
) -> str:
    """
    Trains the model and saves the fine-tuned model to the specified path.

    Args:
        train_dataset (Dataset): Training dataset.
        val_dataset (Dataset): Validation dataset.
        pretrained_model (str, optional): Pretrained model to fine-tune.
        model_save_path (str): Path to save the fine-tuned model. Defaults to './model'.
        output_dir (str, optional): Directory to save training outputs.
        num_epochs (int, optional): Number of training epochs.
        batch_size (int, optional): Batch size for training and evaluation.
        learning_rate (float, optional): Learning rate for training.

    Returns:
        str: Path to the saved fine-tuned model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SegformerForSemanticSegmentation.from_pretrained(pretrained_model).to(
        device
    )
    data_collator = DefaultDataCollator(return_tensors="pt")

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_dir="./logs",
        learning_rate=learning_rate,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )

    trainer.train()
    model.save_pretrained(model_save_path)
    print(f"Model saved to {model_save_path}")
    return model_save_path
```
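
A minimal end-to-end training sketch using the helpers above (data paths are hypothetical; hyperparameters are illustrative only):

```python
transform = get_transform()
train_dataset, val_dataset = prepare_datasets("data/images", "data/masks", transform)
model_path = train_model(
    train_dataset,
    val_dataset,
    num_epochs=5,
    batch_size=4,
    learning_rate=5e-5,
)
```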

visualize_predictions(image_path, segmented_mask, target_size=(256, 256), reference_image_path=None)

Visualizes the original image, segmented mask, and optionally the reference image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image_path | str | Path to the original image. | required |
| segmented_mask | np.ndarray | Predicted segmentation mask. | required |
| target_size | tuple | Target size for resizing images. | (256, 256) |
| reference_image_path | str | Path to the reference image. | None |
Source code in geoai/segmentation.py
```python
def visualize_predictions(
    image_path: str,
    segmented_mask: np.ndarray,
    target_size: tuple = (256, 256),
    reference_image_path: str = None,
) -> None:
    """
    Visualizes the original image, segmented mask, and optionally the reference image.

    Args:
        image_path (str): Path to the original image.
        segmented_mask (np.ndarray): Predicted segmentation mask.
        target_size (tuple, optional): Target size for resizing images.
        reference_image_path (str, optional): Path to the reference image.
    """
    original_image = Image.open(image_path).convert("RGB")
    original_image = original_image.resize(target_size)
    segmented_image = Image.fromarray((segmented_mask * 255).astype(np.uint8))

    if reference_image_path:
        reference_image = Image.open(reference_image_path).convert("RGB")
        reference_image = reference_image.resize(target_size)
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        axes[1].imshow(reference_image)
        axes[1].set_title("Reference Image")
        axes[1].axis("off")
    else:
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    axes[0].imshow(original_image)
    axes[0].set_title("Original Image")
    axes[0].axis("off")

    if reference_image_path:
        axes[2].imshow(segmented_image, cmap="gray")
        axes[2].set_title("Segmented Image")
        axes[2].axis("off")
    else:
        axes[1].imshow(segmented_image, cmap="gray")
        axes[1].set_title("Segmented Image")
        axes[1].axis("off")

    plt.tight_layout()
    plt.show()
```
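
Putting it together with `segment_image()` (hypothetical paths; the reference mask is optional):

```python
mask = segment_image("example.jpg", "./model")
visualize_predictions(
    "example.jpg",
    mask,
    reference_image_path="example_mask.jpg",  # omit if no ground truth is available
)
```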