Train a Semantic Segmentation Model using Segmentation-Models-PyTorch¶
This notebook demonstrates how to train semantic segmentation models to extract objects such as buildings using the segmentation-models-pytorch library. Unlike instance segmentation with Mask R-CNN, which outlines each object separately, this approach treats the task as pixel-level binary classification: every pixel is labeled either building or background.
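With num_classes=2 (background and building), the model outputs one score per class for every pixel, and the predicted mask assigns each pixel the higher-scoring class. A minimal NumPy sketch with made-up scores, assuming a (classes, height, width) array layout:

```python
import numpy as np

# Made-up model output for a 2x2 image: one score (logit) per class per pixel,
# shape (num_classes, height, width); class 0 = background, class 1 = building.
logits = np.array(
    [
        [[2.0, 0.5], [0.1, -1.0]],  # background scores
        [[-1.0, 1.5], [0.3, 2.0]],  # building scores
    ]
)

# Pixel-level classification: each pixel takes the class with the highest score.
mask = np.argmax(logits, axis=0).astype(np.uint8)
print(mask)
# [[0 1]
#  [1 1]]
```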
Install packages¶
To use the new segmentation functionality, make sure geoai-py is installed. Uncomment the line below to install it if needed.
# %pip install geoai-py
Import libraries¶
import geoai
Download sample data¶
We'll use the same NAIP imagery and building footprints as the Mask R-CNN example so the results are comparable.
train_raster_url = (
"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_rgb_train.tif"
)
train_vector_url = "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_train_buildings.geojson"
test_raster_url = (
"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_test.tif"
)
train_raster_path = geoai.download_file(train_raster_url)
train_vector_path = geoai.download_file(train_vector_url)
test_raster_path = geoai.download_file(test_raster_url)
Visualize sample data¶
geoai.view_vector_interactive(train_vector_path, tiles=train_raster_url)
geoai.view_raster(test_raster_url)
Create training data¶
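We'll create the same training tiles as in the Mask R-CNN example: 512×512-pixel image tiles with matching label masks rasterized from the building polygons, taken every 256 pixels so that adjacent tiles overlap by half.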
out_folder = "buildings"
tiles = geoai.export_geotiff_tiles(
in_raster=train_raster_path,
out_folder=out_folder,
in_class_data=train_vector_path,
tile_size=512,
stride=256,
buffer_radius=0,
)
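To see how tile_size and stride interact, the sketch below computes tile start offsets along one axis. It is a hypothetical helper, not part of geoai; in particular, aligning a final tile to the raster edge is an assumption, and export_geotiff_tiles may handle partial edge tiles differently.

```python
def tile_offsets(length, tile_size=512, stride=256):
    """Start offsets of tiles along one axis (hypothetical helper).

    Tiles start every `stride` pixels. If the last regular tile does not reach
    the end of the axis, one extra tile is aligned to the edge (an assumption;
    geoai's own edge handling may differ).
    """
    if length <= tile_size:
        return [0]
    offsets = list(range(0, length - tile_size + 1, stride))
    if offsets[-1] + tile_size < length:
        offsets.append(length - tile_size)
    return offsets


print(tile_offsets(1280))  # [0, 256, 512, 768]
print(tile_offsets(1300))  # [0, 256, 512, 768, 788]
```

With stride equal to half the tile size, most pixels appear in two tiles along each axis (up to four in 2D), which enlarges the training set at the cost of redundant samples.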
Train semantic segmentation model¶
Now we'll train a semantic segmentation model using the new train_segmentation_model function. It supports various architectures and encoders from segmentation-models-pytorch, including:
- Architectures: unet, unetplusplus, deeplabv3, deeplabv3plus, fpn, pspnet, linknet, manet, segformer
- Encoders: resnet34, resnet50, efficientnet-b0, mobilenet_v2, etc.
For more details, please refer to the segmentation-models-pytorch documentation.
Example 1: U-Net with ResNet34 encoder¶
# Train U-Net model
geoai.train_segmentation_model(
images_dir=f"{out_folder}/images",
labels_dir=f"{out_folder}/labels",
output_dir=f"{out_folder}/unet_models",
architecture="unet",
encoder_name="resnet34",
encoder_weights="imagenet",
num_channels=3,
num_classes=2, # background and building
batch_size=8,
num_epochs=50,
learning_rate=0.001,
val_split=0.2,
verbose=True,
)
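Setting val_split=0.2 holds out 20% of the tiles for validation. How geoai draws that split is not shown in this notebook; the sketch below assumes a simple seeded random split at the tile level, using a hypothetical split_tiles helper:

```python
import random


def split_tiles(tile_names, val_split=0.2, seed=42):
    """Shuffle tile names and hold out a fraction for validation (hypothetical helper)."""
    names = sorted(tile_names)  # deterministic starting order
    rng = random.Random(seed)  # local RNG so the split is reproducible
    rng.shuffle(names)
    n_val = int(len(names) * val_split)
    return names[n_val:], names[:n_val]  # (train, val)


example_tiles = [f"tile_{i:04d}.tif" for i in range(100)]
train_tiles, val_tiles = split_tiles(example_tiles)
print(len(train_tiles), len(val_tiles))  # 80 20
```

Because the tiles were exported with stride=256 and tile_size=512, neighboring tiles share pixels, so a random tile-level split can place overlapping tiles in both sets and make validation scores somewhat optimistic. Evaluating on the separate test image, as done below, avoids this.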
Example 2: SegFormer with ResNet152 encoder¶
geoai.train_segmentation_model(
images_dir=f"{out_folder}/images",
labels_dir=f"{out_folder}/labels",
output_dir=f"{out_folder}/segformer_models",
architecture="segformer",
encoder_name="resnet152",
encoder_weights="imagenet",
num_channels=3,
num_classes=2,
batch_size=6, # Smaller batch size for the larger model to reduce GPU memory use
num_epochs=50,
learning_rate=0.0005,
val_split=0.2,
)
Run inference¶
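Now we'll use the best U-Net checkpoint from Example 1 to make predictions on the test image. The architecture, encoder_name, num_channels, and num_classes arguments must match the values used during training so that the saved weights load into the same network.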
# Define paths
masks_path = "naip_test_semantic_prediction.tif"
model_path = f"{out_folder}/unet_models/best_model.pth"
# Run semantic segmentation inference
geoai.semantic_segmentation(
input_path=test_raster_path,
output_path=masks_path,
model_path=model_path,
architecture="unet",
encoder_name="resnet34",
num_channels=3,
num_classes=2,
window_size=512,
overlap=256,
batch_size=4,
)
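The test image is larger than the 512-pixel window, so inference runs over overlapping windows and the predictions are merged. The sketch below averages class probabilities where windows overlap, one common blending strategy; it assumes overlap means pixels shared by adjacent windows (stride = window_size - overlap), and predict_fn and toy_predict are hypothetical stand-ins for a trained model. geoai.semantic_segmentation may blend windows differently.

```python
import numpy as np


def sliding_window_predict(image, predict_fn, window_size=512, overlap=256, num_classes=2):
    """Average class probabilities over overlapping windows (illustrative sketch).

    image: array of shape (channels, height, width).
    predict_fn: hypothetical callable mapping a (channels, h, w) window
        to (num_classes, h, w) probabilities.
    """
    _, height, width = image.shape
    stride = window_size - overlap
    probs = np.zeros((num_classes, height, width), dtype=np.float32)
    counts = np.zeros((height, width), dtype=np.float32)

    def starts(length):
        # Windows start every `stride` pixels; the last one is aligned to the edge.
        if length <= window_size:
            return [0]
        s = list(range(0, length - window_size + 1, stride))
        if s[-1] + window_size < length:
            s.append(length - window_size)
        return s

    for y in starts(height):
        for x in starts(width):
            window = image[:, y:y + window_size, x:x + window_size]
            probs[:, y:y + window_size, x:x + window_size] += predict_fn(window)
            counts[y:y + window_size, x:x + window_size] += 1

    probs /= counts  # every pixel is covered by at least one window
    return np.argmax(probs, axis=0).astype(np.uint8)


def toy_predict(window):
    # Stand-in for a trained model: building probability = mean brightness / 255.
    p_building = window.mean(axis=0) / 255.0
    return np.stack([1.0 - p_building, p_building])


image = np.zeros((3, 1024, 1024), dtype=np.float32)
image[:, 100:300, 600:900] = 255.0  # one bright 200 x 300 "building"
mask = sliding_window_predict(image, toy_predict)
print(mask.shape, mask.sum())  # (1024, 1024) 60000
```

Averaging reduces seams at window borders, where a model has less spatial context than near a window's center.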
Vectorize masks¶
Convert the predicted mask to vector polygons and regularize them into right-angled building footprints for easier visualization and analysis.
output_vector_path = "naip_test_semantic_prediction.geojson"
gdf = geoai.orthogonalize(masks_path, output_vector_path, epsilon=2)
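Polygons traced from a pixel mask have jagged, staircase edges. Assuming epsilon acts as a Douglas-Peucker simplification tolerance (in the mask's pixel or coordinate units), the pure-Python sketch below shows what such a tolerance does: vertices closer than epsilon to the simplified line are dropped. It covers only the simplification step, not the squaring of corners that orthogonalize also performs.

```python
import math


def _point_line_distance(p, a, b):
    """Perpendicular distance from point p to the line through a and b."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    dx, dy = bx - ax, by - ay
    if dx == 0 and dy == 0:
        return math.hypot(px - ax, py - ay)
    return abs(dy * px - dx * py + bx * ay - by * ax) / math.hypot(dx, dy)


def douglas_peucker(points, epsilon):
    """Simplify a polyline, keeping only vertices farther than epsilon from the chord."""
    if len(points) < 3:
        return list(points)
    # Find the interior vertex farthest from the segment joining the endpoints.
    dists = [_point_line_distance(p, points[0], points[-1]) for p in points[1:-1]]
    idx = max(range(len(dists)), key=dists.__getitem__) + 1
    if dists[idx - 1] <= epsilon:
        return [points[0], points[-1]]
    # Otherwise keep that vertex and simplify each half recursively.
    left = douglas_peucker(points[: idx + 1], epsilon)
    right = douglas_peucker(points[idx:], epsilon)
    return left[:-1] + right


# A jagged building edge traced from pixels, followed by a corner.
edge = [(0, 0), (1, 0.4), (2, -0.3), (3, 0.2), (4, 0), (4, 3)]
print(douglas_peucker(edge, 2))  # [(0, 0), (4, 0), (4, 3)]
```

A larger epsilon removes more small wiggles but can also cut off genuine short wall segments.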
Visualize results¶
geoai.view_vector_interactive(output_vector_path, tiles=test_raster_url)
geoai.create_split_map(
left_layer=output_vector_path,
right_layer=test_raster_url,
left_args={"style": {"color": "red", "fillOpacity": 0.2}},
basemap=test_raster_url,
)
Model Performance Analysis¶
Let's examine the training curves and model performance:
# Load and display training history
import torch
import matplotlib.pyplot as plt
history_path = f"{out_folder}/unet_models/training_history.pth"
history = torch.load(history_path)
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.plot(history["train_losses"], label="Train Loss")
plt.plot(history["val_losses"], label="Val Loss")
plt.title("Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.subplot(1, 3, 2)
plt.plot(history["val_ious"], label="Val IoU")
plt.title("IoU Score")
plt.xlabel("Epoch")
plt.ylabel("IoU")
plt.legend()
plt.grid(True)
plt.subplot(1, 3, 3)
plt.plot(history["val_dices"], label="Val Dice")
plt.title("Dice Score")
plt.xlabel("Epoch")
plt.ylabel("Dice")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
print(f"Best IoU: {max(history['val_ious']):.4f}")
print(f"Best Dice: {max(history['val_dices']):.4f}")
print(f"Final IoU: {history['val_ious'][-1]:.4f}")
print(f"Final Dice: {history['val_dices'][-1]:.4f}")
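The history file stores validation IoU and Dice per epoch. For reference, here is a minimal NumPy sketch of both metrics on a single pair of binary masks. geoai's own implementation may differ in details such as averaging across batches; scoring an empty prediction against an empty target as 1.0 is an assumed convention.

```python
import numpy as np


def iou_and_dice(pred, target):
    """IoU and Dice for binary masks (1 = building, 0 = background)."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    # Assumed convention: two empty masks count as a perfect match.
    iou = intersection / union if union else 1.0
    dice = 2 * intersection / total if total else 1.0
    return float(iou), float(dice)


pred = [[1, 1, 0], [0, 1, 0]]
target = [[1, 0, 0], [0, 1, 1]]
iou, dice = iou_and_dice(pred, target)
print(f"IoU={iou:.3f}, Dice={dice:.3f}")  # IoU=0.500, Dice=0.667
```

For a single mask pair, the two metrics are linked by Dice = 2·IoU / (1 + IoU), so Dice is always at least as high as IoU. Values averaged over many tiles need not satisfy this formula exactly.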