# Semantic segmentation
# %%
import os
import torch
from matplotlib import pyplot as plt
from rastervision.core.data import ClassConfig, SemanticSegmentationLabels
import albumentations as A
from rastervision.pytorch_learner import (
SemanticSegmentationRandomWindowGeoDataset,
SemanticSegmentationSlidingWindowGeoDataset,
SemanticSegmentationVisualizer,
SemanticSegmentationGeoDataConfig,
SemanticSegmentationLearnerConfig,
SolverConfig,
SemanticSegmentationLearner,
)
import os
import torch
from matplotlib import pyplot as plt
from rastervision.core.data import ClassConfig, SemanticSegmentationLabels
import albumentations as A
from rastervision.pytorch_learner import (
SemanticSegmentationRandomWindowGeoDataset,
SemanticSegmentationSlidingWindowGeoDataset,
SemanticSegmentationVisualizer,
SemanticSegmentationGeoDataConfig,
SemanticSegmentationLearnerConfig,
SolverConfig,
SemanticSegmentationLearner,
)
# %%
# GDAL config option: send unsigned requests, so the public SpaceNet S3 bucket
# can be read without AWS credentials. Setting it once is sufficient.
os.environ["AWS_NO_SIGN_REQUEST"] = "YES"
# %%
# Two-class problem: "background" is declared as the null class, and "building"
# is the target. Colours are used only for visualisation.
class_config = ClassConfig(
    names=["background", "building"],
    colors=["lightgray", "darkred"],
    null_class="background",
)

# Visualiser that renders chips and masks using the class names/colours above.
viz = SemanticSegmentationVisualizer(
    class_names=class_config.names, class_colors=class_config.colors
)
# %%
# Training and validation scenes: the 2018-01 monthly mosaic of two different
# SpaceNet 7 tiles, each paired with its building-footprint GeoJSON.
train_image_uri = "s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/images/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif"
train_label_uri = "s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson"
val_image_uri = "s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13.tif"
val_label_uri = "s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson"
# %%
# Prediction scene: the 2020-01 mosaic of the validation tile (same location,
# two years later) and its building-footprint GeoJSON.
pred_image_uri = "s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2020_01_mosaic_L15-0357E-1223N_1429_3296_13.tif"
pred_label_uri = "s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2020_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson"
# %%
# Training-time augmentation: geometric jitter, at most one photometric change
# (A.OneOf selects a single transform from its list), and coarse dropout of up
# to 5 holes no larger than 32x32 px.
data_augmentation_transform = A.Compose(
    [
        A.Flip(),
        A.ShiftScaleRotate(),
        A.OneOf(
            [
                A.HueSaturationValue(hue_shift_limit=10),
                A.RGBShift(),
                A.ToGray(),
                A.ToSepia(),
                # Brightness-only jitter. Equivalent to the deprecated
                # `A.RandomBrightness()` (default limit=0.2, contrast fixed),
                # which emits a FutureWarning and was removed in later
                # albumentations releases.
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0),
                A.RandomGamma(),
            ]
        ),
        A.CoarseDropout(max_height=32, max_width=32, max_holes=5),
    ]
)
# %%
# Training dataset: up to 400 randomly placed windows sampled from the training
# scene. Window side lengths are constrained by size_lims; chips are output at
# out_size (presumably resized to 256x256 — confirm against Raster Vision docs).
# Footprint polygons in the GeoJSON are labelled with the "building" class id.
train_ds = SemanticSegmentationRandomWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=train_image_uri,
    label_vector_uri=train_label_uri,
    label_vector_default_class_id=class_config.get_class_id("building"),
    size_lims=(150, 200),
    out_size=256,
    max_windows=400,
    transform=data_augmentation_transform,
)
len(train_ds)
# %%
# Preview 4 augmented training chips with their masks.
x, y = viz.get_batch(train_ds, 4)
viz.plot_batch(x, y, show=True)
# %%
# Validation dataset: 200x200 px sliding windows with stride 100 (50% overlap),
# each resized to 256x256 to match the training chip size. No augmentation.
val_ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=val_image_uri,
    label_vector_uri=val_label_uri,
    label_vector_default_class_id=class_config.get_class_id("building"),
    size=200,
    stride=100,
    transform=A.Resize(256, 256),
)
len(val_ds)
# %%
# Preview 4 validation chips with their masks.
x, y = viz.get_batch(val_ds, 4)
viz.plot_batch(x, y, show=True)
# %%
# Prediction dataset: image only (no labels), tiled into 200x200 px windows with
# stride 100 and resized to 256x256 for the model. Any prediction out_shape used
# with this dataset should match the 200 px window size.
pred_ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=pred_image_uri,
    size=200,
    stride=100,
    transform=A.Resize(256, 256),
)
len(pred_ds)
# %%
# FPN segmentation model with a ResNet-18 backbone, loaded from the
# AdeelH/pytorch-fpn repo (tag 0.3) via torch.hub. It takes 3-band input, has
# one output channel per class, and produces 256x256 outputs to match the chip
# size. pretrained=True fetches backbone weights, so it is loaded once.
model = torch.hub.load(
    "AdeelH/pytorch-fpn:0.3",
    "make_fpn_resnet",
    name="resnet18",
    fpn_type="panoptic",
    num_classes=len(class_config),
    fpn_channels=128,
    in_channels=3,
    out_size=(256, 256),
    pretrained=True,
)
# %%
# Data config for the learner: reuses the class names/colours from class_config.
data_cfg = SemanticSegmentationGeoDataConfig(
    class_names=class_config.names,
    class_colors=class_config.colors,
    num_workers=0,  # increase to use multi-processing
)
# %%
# Optimisation settings. The loss weights follow class_config.names order, so
# "building" errors weigh 10x "background" errors to offset class imbalance.
solver_cfg = SolverConfig(batch_sz=8, lr=3e-2, class_loss_weights=[1.0, 10.0])
# %%
# Combine the data and solver configs into a single learner config.
learner_cfg = SemanticSegmentationLearnerConfig(data=data_cfg, solver=solver_cfg)
# %%
# Learner that trains `model` on train_ds and validates on val_ds. Outputs
# (including the tb-logs directory read by TensorBoard) go to ./train-demo/.
learner = SemanticSegmentationLearner(
    cfg=learner_cfg,
    output_dir="./train-demo/",
    model=model,
    train_ds=train_ds,
    valid_ds=val_ds,
)
# %%
# Log summary statistics of the training/validation data.
learner.log_data_stats()
# %%
%load_ext tensorboard
%load_ext tensorboard
# %%
%tensorboard --bind_all --logdir "./train-demo/tb-logs" --reload_interval 10
%tensorboard --bind_all --logdir "./train-demo/tb-logs" --reload_interval 10
# %%
# Initial training run of 3 epochs.
learner.train(epochs=3)
# %%
# Train for 1 more epoch on the same learner.
learner.train(epochs=1)
# %%
# Visualise model predictions on the validation split.
learner.plot_predictions(split="valid", show=True)
# %%
# Export the trained model and config as a bundle. The cells below reload it
# from ./train-demo/model-bundle.zip.
learner.save_model_bundle()
# %%
# Rebuild a learner from the saved bundle, reusing the in-memory `model`
# architecture. No datasets are attached, so this learner is for inference.
learner = SemanticSegmentationLearner.from_model_bundle(
    model_bundle_uri="./train-demo/model-bundle.zip",
    output_dir="./train-demo/",
    model=model,
)
# %%
# Rebuild a learner from the saved bundle with the train/valid datasets attached
# and training=True, so training can continue from the bundled weights.
learner = SemanticSegmentationLearner.from_model_bundle(
    model_bundle_uri="./train-demo/model-bundle.zip",
    output_dir="./train-demo/",
    model=model,
    train_ds=train_ds,
    valid_ds=val_ds,
    training=True,
)
# %%
# Continue training the reloaded learner for 1 epoch.
learner.train(epochs=1)
# %%
# Visualise validation predictions after the resumed training.
learner.plot_predictions(split="valid", show=True)
# %%
# Run the model over every sliding window of the prediction scene.
# raw_out=True presumably returns per-class scores rather than class ids
# (confirm in Learner.predict docs); numpy_out=True returns NumPy arrays.
# out_shape resizes each 256x256 model output back to the window's native size.
# It must equal the `size=200` used to build pred_ds. The previous (325, 325)
# did not match the 200x200 windows, so predictions could not line up with
# pred_ds.windows when stitched into scene-level labels.
predictions = learner.predict_dataset(
    pred_ds,
    raw_out=True,
    numpy_out=True,
    predict_kw=dict(out_shape=(200, 200)),
    progress_bar=True,
)
# %%
# Stitch the per-window predictions into labels covering the whole prediction
# scene. smooth=True keeps per-class scores; overlapping windows (stride 100 <
# size 200) are presumably combined by averaging — confirm in Raster Vision docs.
pred_labels = SemanticSegmentationLabels.from_predictions(
    pred_ds.windows,
    predictions,
    smooth=True,
    extent=pred_ds.scene.extent,
    num_classes=len(class_config),
)
# %%
# Score array covering the full extent of the stitched labels.
scores = pred_labels.get_score_arr(pred_labels.extent)
# %%
# Write the stitched predictions to the local "predict" location (presumably a
# directory of output files — confirm in Raster Vision docs), georeferenced with
# the prediction scene's CRS transformer. A plain string is used instead of an
# f-string with no placeholders.
pred_labels.save(
    uri="predict",
    crs_transformer=pred_ds.scene.raster_source.crs_transformer,
    class_config=class_config,
)