train module¶
Compose¶
Custom compose transform that works with image and target.
Source code in geoai/train.py
class Compose:
"""Custom compose transform that works with image and target."""
def __init__(self, transforms):
"""
Initialize compose transform.
Args:
transforms (list): List of transforms to apply.
"""
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
__init__(self, transforms) special¶
Initialize compose transform.
Parameters:

Name | Type | Description | Default
---|---|---|---
transforms | list | List of transforms to apply. | required
Source code in geoai/train.py
def __init__(self, transforms):
"""
Initialize compose transform.
Args:
transforms (list): List of transforms to apply.
"""
self.transforms = transforms
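A minimal usage sketch of the composed pipeline (the dummy tensor and the empty target dict below are illustrative assumptions, not part of the API):

```python
import torch

from geoai.train import Compose, RandomHorizontalFlip, ToTensor

# Same pipeline that get_transform(train=True) builds.
transforms = Compose([ToTensor(), RandomHorizontalFlip(prob=0.5)])

# Dummy 3-band 256x256 image with an empty target, just to show the call signature.
image = torch.rand(3, 256, 256)
target = {"boxes": torch.zeros((0, 4)), "masks": torch.zeros((0, 256, 256))}
image, target = transforms(image, target)
```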
ObjectDetectionDataset (Dataset)¶
Dataset for object detection from GeoTIFF images and labels.
Source code in geoai/train.py
class ObjectDetectionDataset(Dataset):
"""Dataset for object detection from GeoTIFF images and labels."""
def __init__(self, image_paths, label_paths, transforms=None, num_channels=None):
"""
Initialize dataset.
Args:
image_paths (list): List of paths to image GeoTIFF files.
label_paths (list): List of paths to label GeoTIFF files.
transforms (callable, optional): Transformations to apply to images and masks.
num_channels (int, optional): Number of channels to use from images. If None,
auto-detected from the first image.
"""
self.image_paths = image_paths
self.label_paths = label_paths
self.transforms = transforms
# Auto-detect the number of channels if not specified
if num_channels is None:
with rasterio.open(self.image_paths[0]) as src:
self.num_channels = src.count
else:
self.num_channels = num_channels
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# Load image
with rasterio.open(self.image_paths[idx]) as src:
# Read as [C, H, W] format
image = src.read().astype(np.float32)
# Normalize image to [0, 1] range
image = image / 255.0
# Handle different number of channels
            if image.shape[0] > self.num_channels:
                # Keep only the first num_channels bands if more exist
                image = image[: self.num_channels]
            elif image.shape[0] < self.num_channels:
                # Pad with zeros if fewer than num_channels bands exist
padded = np.zeros(
(self.num_channels, image.shape[1], image.shape[2]),
dtype=np.float32,
)
padded[: image.shape[0]] = image
image = padded
# Convert to CHW tensor
image = torch.as_tensor(image, dtype=torch.float32)
# Load label mask
with rasterio.open(self.label_paths[idx]) as src:
label_mask = src.read(1)
binary_mask = (label_mask > 0).astype(np.uint8)
# Find all building instances using connected components
labeled_mask, num_instances = measure.label(
binary_mask, return_num=True, connectivity=2
)
# Create list to hold masks for each building instance
masks = []
boxes = []
labels = []
for i in range(1, num_instances + 1):
# Create mask for this instance
instance_mask = (labeled_mask == i).astype(np.uint8)
# Calculate area and filter out tiny instances (noise)
area = instance_mask.sum()
if area < 10: # Minimum area threshold
continue
# Find bounding box coordinates
pos = np.where(instance_mask)
if len(pos[0]) == 0: # Skip if mask is empty
continue
xmin = np.min(pos[1])
xmax = np.max(pos[1])
ymin = np.min(pos[0])
ymax = np.max(pos[0])
# Skip invalid boxes
if xmax <= xmin or ymax <= ymin:
continue
# Add small padding to ensure the mask is within the box
xmin = max(0, xmin - 1)
ymin = max(0, ymin - 1)
xmax = min(binary_mask.shape[1] - 1, xmax + 1)
ymax = min(binary_mask.shape[0] - 1, ymax + 1)
boxes.append([xmin, ymin, xmax, ymax])
masks.append(instance_mask)
labels.append(1) # 1 for building class
# Handle case with no valid instances
if len(boxes) == 0:
# Create a dummy target with minimal required fields
target = {
"boxes": torch.zeros((0, 4), dtype=torch.float32),
"labels": torch.zeros((0), dtype=torch.int64),
"masks": torch.zeros(
(0, binary_mask.shape[0], binary_mask.shape[1]), dtype=torch.uint8
),
"image_id": torch.tensor([idx]),
"area": torch.zeros((0), dtype=torch.float32),
"iscrowd": torch.zeros((0), dtype=torch.int64),
}
else:
# Convert to tensors
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
masks = torch.as_tensor(np.array(masks), dtype=torch.uint8)
# Calculate area of boxes
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# Prepare target dictionary
target = {
"boxes": boxes,
"labels": labels,
"masks": masks,
"image_id": torch.tensor([idx]),
"area": area,
"iscrowd": torch.zeros_like(labels), # Assume no crowd instances
}
# Apply transforms if specified
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
__init__(self, image_paths, label_paths, transforms=None, num_channels=None) special¶
Initialize dataset.
Parameters:

Name | Type | Description | Default
---|---|---|---
image_paths | list | List of paths to image GeoTIFF files. | required
label_paths | list | List of paths to label GeoTIFF files. | required
transforms | callable | Transformations to apply to images and masks. | None
num_channels | int | Number of channels to use from images. If None, auto-detected from the first image. | None
Source code in geoai/train.py
def __init__(self, image_paths, label_paths, transforms=None, num_channels=None):
"""
Initialize dataset.
Args:
image_paths (list): List of paths to image GeoTIFF files.
label_paths (list): List of paths to label GeoTIFF files.
transforms (callable, optional): Transformations to apply to images and masks.
num_channels (int, optional): Number of channels to use from images. If None,
auto-detected from the first image.
"""
self.image_paths = image_paths
self.label_paths = label_paths
self.transforms = transforms
# Auto-detect the number of channels if not specified
if num_channels is None:
with rasterio.open(self.image_paths[0]) as src:
self.num_channels = src.count
else:
self.num_channels = num_channels
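A short construction sketch; the directory layout is hypothetical, and the files must be rasterio-readable GeoTIFF pairs:

```python
import glob

from geoai.train import ObjectDetectionDataset, get_transform

image_paths = sorted(glob.glob("data/images/*.tif"))  # hypothetical paths
label_paths = sorted(glob.glob("data/labels/*.tif"))

dataset = ObjectDetectionDataset(
    image_paths, label_paths, transforms=get_transform(train=True)
)
image, target = dataset[0]
print(image.shape)            # [num_channels, H, W], float32 scaled to [0, 1]
print(target["boxes"].shape)  # [num_instances, 4] as (xmin, ymin, xmax, ymax)
```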
RandomHorizontalFlip¶
Random horizontal flip transform.
Source code in geoai/train.py
class RandomHorizontalFlip:
"""Random horizontal flip transform."""
def __init__(self, prob=0.5):
"""
Initialize random horizontal flip.
Args:
prob (float): Probability of applying the flip.
"""
self.prob = prob
def __call__(self, image, target):
if random.random() < self.prob:
# Flip image
image = torch.flip(image, dims=[2]) # Flip along width dimension
# Flip masks
if "masks" in target and len(target["masks"]) > 0:
target["masks"] = torch.flip(target["masks"], dims=[2])
# Update boxes
if "boxes" in target and len(target["boxes"]) > 0:
boxes = target["boxes"]
width = image.shape[2]
boxes[:, 0], boxes[:, 2] = width - boxes[:, 2], width - boxes[:, 0]
target["boxes"] = boxes
return image, target
__init__(self, prob=0.5) special¶
Initialize random horizontal flip.
Parameters:

Name | Type | Description | Default
---|---|---|---
prob | float | Probability of applying the flip. | 0.5
Source code in geoai/train.py
def __init__(self, prob=0.5):
"""
Initialize random horizontal flip.
Args:
prob (float): Probability of applying the flip.
"""
self.prob = prob
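The box update above mirrors and swaps the x coordinates (xmin' = width - xmax, xmax' = width - xmin). A quick deterministic check, with prob=1.0 forcing the flip and toy shapes chosen for illustration:

```python
import torch

from geoai.train import RandomHorizontalFlip

flip = RandomHorizontalFlip(prob=1.0)  # always flips, for a deterministic check
image = torch.zeros(3, 4, 10)          # width = 10
target = {
    "boxes": torch.tensor([[2.0, 0.0, 5.0, 3.0]]),
    "masks": torch.zeros((1, 4, 10), dtype=torch.uint8),
}
image, target = flip(image, target)
print(target["boxes"])  # tensor([[5., 0., 8., 3.]])
```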
ToTensor¶
Convert numpy.ndarray to tensor.
Source code in geoai/train.py
class ToTensor:
"""Convert numpy.ndarray to tensor."""
def __call__(self, image, target):
"""
Apply transform to image and target.
Args:
image (torch.Tensor): Input image.
target (dict): Target annotations.
Returns:
tuple: Transformed image and target.
"""
return image, target
__call__(self, image, target) special¶
Apply transform to image and target.
Parameters:

Name | Type | Description | Default
---|---|---|---
image | torch.Tensor | Input image. | required
target | dict | Target annotations. | required

Returns:

Type | Description
---|---
tuple | Transformed image and target.
Source code in geoai/train.py
def __call__(self, image, target):
"""
Apply transform to image and target.
Args:
image (torch.Tensor): Input image.
target (dict): Target annotations.
Returns:
tuple: Transformed image and target.
"""
return image, target
collate_fn(batch)¶
Custom collate function for batching samples.
Parameters:

Name | Type | Description | Default
---|---|---|---
batch | list | List of (image, target) tuples. | required

Returns:

Type | Description
---|---
tuple | Tuple of images and targets.
Source code in geoai/train.py
def collate_fn(batch):
"""
Custom collate function for batching samples.
Args:
batch (list): List of (image, target) tuples.
Returns:
tuple: Tuple of images and targets.
"""
return tuple(zip(*batch))
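Because each image can hold a different number of instances, samples cannot be stacked into one tensor; this collate keeps them as tuples, which is the input format torchvision's detection models expect. A sketch with hypothetical data paths:

```python
import glob

from torch.utils.data import DataLoader

from geoai.train import ObjectDetectionDataset, collate_fn, get_transform

dataset = ObjectDetectionDataset(
    sorted(glob.glob("data/images/*.tif")),  # hypothetical paths
    sorted(glob.glob("data/labels/*.tif")),
    transforms=get_transform(train=False),
)
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
images, targets = next(iter(loader))
# images: tuple of 4 CHW tensors; targets: tuple of 4 target dicts.
```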
evaluate(model, data_loader, device)¶
Evaluate the model on the validation set.
Parameters:

Name | Type | Description | Default
---|---|---|---
model | torch.nn.Module | The model to evaluate. | required
data_loader | torch.utils.data.DataLoader | DataLoader for validation data. | required
device | torch.device | Device to evaluate on. | required

Returns:

Type | Description
---|---
dict | Evaluation metrics including loss and IoU.
Source code in geoai/train.py
def evaluate(model, data_loader, device):
"""
Evaluate the model on the validation set.
Args:
model (torch.nn.Module): The model to evaluate.
data_loader (torch.utils.data.DataLoader): DataLoader for validation data.
device (torch.device): Device to evaluate on.
Returns:
dict: Evaluation metrics including loss and IoU.
"""
model.eval()
# Initialize metrics
total_loss = 0
iou_scores = []
with torch.no_grad():
for images, targets in data_loader:
# Move to device
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
# During evaluation, Mask R-CNN directly returns predictions, not losses
# So we'll only get loss when we provide targets explicitly
if len(targets) > 0:
try:
# Try to get loss dict (this works in some implementations)
loss_dict = model(images, targets)
if isinstance(loss_dict, dict):
losses = sum(loss for loss in loss_dict.values())
total_loss += losses.item()
except Exception as e:
print(f"Warning: Could not compute loss during evaluation: {e}")
# If we can't compute loss, we'll just focus on IoU
pass
# Get predictions
outputs = model(images)
# Calculate IoU for each image
for i, output in enumerate(outputs):
if len(output["masks"]) == 0 or len(targets[i]["masks"]) == 0:
continue
# Convert predicted masks to binary (threshold at 0.5)
pred_masks = (output["masks"].squeeze(1) > 0.5).float()
# Combine all instance masks into a single binary mask
pred_combined = (
torch.max(pred_masks, dim=0)[0]
if pred_masks.shape[0] > 0
else torch.zeros_like(targets[i]["masks"][0])
)
target_combined = (
torch.max(targets[i]["masks"], dim=0)[0]
if targets[i]["masks"].shape[0] > 0
else torch.zeros_like(pred_combined)
)
# Calculate IoU
intersection = (pred_combined * target_combined).sum().item()
union = ((pred_combined + target_combined) > 0).sum().item()
if union > 0:
iou = intersection / union
iou_scores.append(iou)
# Calculate metrics
avg_loss = total_loss / len(data_loader) if total_loss > 0 else float("inf")
avg_iou = sum(iou_scores) / len(iou_scores) if iou_scores else 0
return {"loss": avg_loss, "IoU": avg_iou}
get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True)¶
Get Mask R-CNN model with custom input channels and output classes.
Parameters:

Name | Type | Description | Default
---|---|---|---
num_classes | int | Number of output classes (including background). | 2
num_channels | int | Number of input channels (3 for RGB, 4 for RGBN). | 3
pretrained | bool | Whether to use pretrained backbone. | True

Returns:

Type | Description
---|---
torch.nn.Module | Mask R-CNN model with specified input channels and output classes.

Exceptions:

Type | Description
---|---
ValueError | If num_channels is less than 3.
Source code in geoai/train.py
def get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True):
"""
Get Mask R-CNN model with custom input channels and output classes.
Args:
num_classes (int): Number of output classes (including background).
num_channels (int): Number of input channels (3 for RGB, 4 for RGBN).
pretrained (bool): Whether to use pretrained backbone.
Returns:
torch.nn.Module: Mask R-CNN model with specified input channels and output classes.
Raises:
ValueError: If num_channels is less than 3.
"""
# Validate num_channels
if num_channels < 3:
raise ValueError("num_channels must be at least 3")
# Load pre-trained model
    model = maskrcnn_resnet50_fpn(
        progress=True,
        # The weights argument supersedes the deprecated pretrained flag;
        # passing both raises an error in recent torchvision versions.
        weights=(
            torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT
            if pretrained
            else None
        ),
    )
# Modify transform if num_channels is different from 3
if num_channels != 3:
# Get the transform
transform = model.transform
# Default values are [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225]
# Calculate means and stds for additional channels
rgb_mean = [0.485, 0.456, 0.406]
rgb_std = [0.229, 0.224, 0.225]
# Extend them to num_channels (use the mean value for additional channels)
mean_of_means = sum(rgb_mean) / len(rgb_mean)
mean_of_stds = sum(rgb_std) / len(rgb_std)
# Create new lists with appropriate length
transform.image_mean = rgb_mean + [mean_of_means] * (num_channels - 3)
transform.image_std = rgb_std + [mean_of_stds] * (num_channels - 3)
# Get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# Replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# Get number of input features for mask classifier
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
# Replace mask predictor with a new one
model.roi_heads.mask_predictor = MaskRCNNPredictor(
in_features_mask, hidden_layer, num_classes
)
# Modify the first layer if num_channels is different from 3
if num_channels != 3:
original_layer = model.backbone.body.conv1
model.backbone.body.conv1 = torch.nn.Conv2d(
num_channels,
original_layer.out_channels,
kernel_size=original_layer.kernel_size,
stride=original_layer.stride,
padding=original_layer.padding,
bias=original_layer.bias is not None,
)
# Copy weights from the original 3 channels to the new layer
with torch.no_grad():
# Copy the weights for the first 3 channels
model.backbone.body.conv1.weight[:, :3, :, :] = original_layer.weight
# Initialize additional channels with the mean of the first 3 channels
mean_weight = original_layer.weight.mean(dim=1, keepdim=True)
for i in range(3, num_channels):
model.backbone.body.conv1.weight[:, i : i + 1, :, :] = mean_weight
# Copy bias if it exists
if original_layer.bias is not None:
model.backbone.body.conv1.bias = original_layer.bias
return model
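A usage sketch building a 4-channel (e.g. RGB+NIR) variant and running a dummy forward pass; pretrained=True downloads the COCO backbone weights on first use:

```python
import torch

from geoai.train import get_instance_segmentation_model

model = get_instance_segmentation_model(num_classes=2, num_channels=4, pretrained=True)
model.eval()
with torch.no_grad():
    predictions = model([torch.rand(4, 256, 256)])  # list of CHW tensors
print(predictions[0].keys())  # dict_keys(['boxes', 'labels', 'scores', 'masks'])
```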
get_transform(train)¶
Get transforms for data augmentation.
Parameters:

Name | Type | Description | Default
---|---|---|---
train | bool | Whether to include training-specific transforms. | required

Returns:

Type | Description
---|---
Compose | Composed transforms.
Source code in geoai/train.py
def get_transform(train):
"""
Get transforms for data augmentation.
Args:
train (bool): Whether to include training-specific transforms.
Returns:
Compose: Composed transforms.
"""
transforms = []
transforms.append(ToTensor())
if train:
transforms.append(RandomHorizontalFlip(0.5))
return Compose(transforms)
inference_on_geotiff(model, geotiff_path, output_path, window_size=512, overlap=256, confidence_threshold=0.5, batch_size=4, num_channels=3, device=None, **kwargs)¶
Perform inference on a large GeoTIFF using a sliding window approach with improved blending.
Parameters:

Name | Type | Description | Default
---|---|---|---
model | torch.nn.Module | Trained model for inference. | required
geotiff_path | str | Path to input GeoTIFF file. | required
output_path | str | Path to save output mask GeoTIFF. | required
window_size | int | Size of sliding window for inference. | 512
overlap | int | Overlap between adjacent windows. | 256
confidence_threshold | float | Confidence threshold for predictions (0-1). | 0.5
batch_size | int | Batch size for inference. | 4
num_channels | int | Number of channels to use from the input image. | 3
device | torch.device | Device to run inference on. If None, uses CUDA if available. | None
**kwargs | | Additional arguments. | {}

Returns:

Type | Description
---|---
tuple | Tuple containing output path and inference time in seconds.
Source code in geoai/train.py
def inference_on_geotiff(
model,
geotiff_path,
output_path,
window_size=512,
overlap=256,
confidence_threshold=0.5,
batch_size=4,
num_channels=3,
device=None,
**kwargs,
):
"""
Perform inference on a large GeoTIFF using a sliding window approach with improved blending.
Args:
model (torch.nn.Module): Trained model for inference.
geotiff_path (str): Path to input GeoTIFF file.
output_path (str): Path to save output mask GeoTIFF.
window_size (int): Size of sliding window for inference.
overlap (int): Overlap between adjacent windows.
confidence_threshold (float): Confidence threshold for predictions (0-1).
batch_size (int): Batch size for inference.
num_channels (int): Number of channels to use from the input image.
device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
**kwargs: Additional arguments.
Returns:
tuple: Tuple containing output path and inference time in seconds.
"""
if device is None:
device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
# Put model in evaluation mode
model.to(device)
model.eval()
# Open the GeoTIFF
with rasterio.open(geotiff_path) as src:
# Read metadata
meta = src.meta
height = src.height
width = src.width
# Update metadata for output raster
out_meta = meta.copy()
        out_meta.update({"count": 1, "dtype": "uint8"})  # single band for the binary mask
# We'll use two arrays:
# 1. For accumulating predictions
pred_accumulator = np.zeros((height, width), dtype=np.float32)
# 2. For tracking how many predictions contribute to each pixel
count_accumulator = np.zeros((height, width), dtype=np.float32)
# Calculate the number of windows needed to cover the entire image
steps_y = math.ceil((height - overlap) / (window_size - overlap))
steps_x = math.ceil((width - overlap) / (window_size - overlap))
# Ensure we cover the entire image
last_y = height - window_size
last_x = width - window_size
total_windows = steps_y * steps_x
print(
f"Processing {total_windows} windows with size {window_size}x{window_size} and overlap {overlap}..."
)
# Create progress bar
pbar = tqdm(total=total_windows)
# Process in batches
batch_inputs = []
batch_positions = []
batch_count = 0
start_time = time.time()
# Slide window over the image - make sure we cover the entire image
for i in range(steps_y + 1): # +1 to ensure we reach the edge
y = min(i * (window_size - overlap), last_y)
y = max(0, y) # Prevent negative indices
if y > last_y and i > 0: # Skip if we've already covered the entire height
continue
for j in range(steps_x + 1): # +1 to ensure we reach the edge
x = min(j * (window_size - overlap), last_x)
x = max(0, x) # Prevent negative indices
if (
x > last_x and j > 0
): # Skip if we've already covered the entire width
continue
# Read window
window = src.read(window=Window(x, y, window_size, window_size))
# Check if window is valid
if window.shape[1] != window_size or window.shape[2] != window_size:
# This can happen at image edges - adjust window size
current_height = window.shape[1]
current_width = window.shape[2]
if current_height == 0 or current_width == 0:
continue # Skip empty windows
else:
current_height = window_size
current_width = window_size
# Normalize and prepare input
image = window.astype(np.float32) / 255.0
# Handle different number of bands
if image.shape[0] > num_channels:
image = image[:num_channels]
elif image.shape[0] < num_channels:
padded = np.zeros(
(num_channels, current_height, current_width), dtype=np.float32
)
padded[: image.shape[0]] = image
image = padded
# Convert to tensor
image_tensor = torch.tensor(image, device=device)
# Add to batch
batch_inputs.append(image_tensor)
batch_positions.append((y, x, current_height, current_width))
batch_count += 1
# Process batch when it reaches the batch size or at the end
if batch_count == batch_size or (i == steps_y and j == steps_x):
# Forward pass
with torch.no_grad():
outputs = model(batch_inputs)
# Process each output in the batch
for idx, output in enumerate(outputs):
y_pos, x_pos, h, w = batch_positions[idx]
# Create weight matrix that gives higher weight to center pixels
# This helps with smooth blending at boundaries
y_grid, x_grid = np.mgrid[0:h, 0:w]
# Calculate distance from each edge
dist_from_left = x_grid
dist_from_right = w - x_grid - 1
dist_from_top = y_grid
dist_from_bottom = h - y_grid - 1
# Combine distances (minimum distance to any edge)
edge_distance = np.minimum.reduce(
[
dist_from_left,
dist_from_right,
dist_from_top,
dist_from_bottom,
]
)
# Convert to weight (higher weight for center pixels)
# Normalize to [0, 1]
edge_distance = np.minimum(edge_distance, overlap / 2)
weight = edge_distance / (overlap / 2)
# Get masks for predictions above threshold
if len(output["scores"]) > 0:
# Get all instances that meet confidence threshold
keep = output["scores"] > confidence_threshold
masks = output["masks"][keep].squeeze(1)
# Combine all instances into one mask
if len(masks) > 0:
combined_mask = torch.max(masks, dim=0)[0] > 0.5
combined_mask = (
combined_mask.cpu().numpy().astype(np.float32)
)
# Apply weight to prediction
weighted_pred = combined_mask * weight
# Add to accumulators
pred_accumulator[
y_pos : y_pos + h, x_pos : x_pos + w
] += weighted_pred
count_accumulator[
y_pos : y_pos + h, x_pos : x_pos + w
] += weight
# Reset batch
batch_inputs = []
batch_positions = []
batch_count = 0
# Update progress bar
pbar.update(len(outputs))
# Close progress bar
pbar.close()
# Calculate final mask by dividing accumulated predictions by counts
# Handle division by zero
mask = np.zeros((height, width), dtype=np.uint8)
valid_pixels = count_accumulator > 0
if np.any(valid_pixels):
# Average predictions where we have data
mask[valid_pixels] = (
pred_accumulator[valid_pixels] / count_accumulator[valid_pixels] > 0.5
).astype(np.uint8)
# Record time
inference_time = time.time() - start_time
print(f"Inference completed in {inference_time:.2f} seconds")
# Save output
with rasterio.open(output_path, "w", **out_meta) as dst:
dst.write(mask, 1)
print(f"Saved prediction to {output_path}")
return output_path, inference_time
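A usage sketch; the weight file and GeoTIFF paths are hypothetical. A larger overlap gives smoother seams between windows at the cost of more compute:

```python
import torch

from geoai.train import get_instance_segmentation_model, inference_on_geotiff

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_instance_segmentation_model(num_classes=2, num_channels=3)
model.load_state_dict(torch.load("output/best_model.pth", map_location=device))

mask_path, seconds = inference_on_geotiff(
    model,
    geotiff_path="data/naip_scene.tif",  # hypothetical input
    output_path="output/naip_mask.tif",
    window_size=512,
    overlap=256,
    confidence_threshold=0.5,
    batch_size=4,
    device=device,
)
```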
object_detection(input_path, output_path, model_path, window_size=512, overlap=256, confidence_threshold=0.5, batch_size=4, num_channels=3, pretrained=True, device=None, **kwargs)¶
Perform object detection on a GeoTIFF using a pre-trained Mask R-CNN model.
Parameters:

Name | Type | Description | Default
---|---|---|---
input_path | str | Path to input GeoTIFF file. | required
output_path | str | Path to save output mask GeoTIFF. | required
model_path | str | Path to trained model weights. | required
window_size | int | Size of sliding window for inference. | 512
overlap | int | Overlap between adjacent windows. | 256
confidence_threshold | float | Confidence threshold for predictions (0-1). | 0.5
batch_size | int | Batch size for inference. | 4
num_channels | int | Number of channels in the input image and model. | 3
pretrained | bool | Whether to use pretrained backbone for model loading. | True
device | torch.device | Device to run inference on. If None, uses CUDA if available. | None
**kwargs | | Additional arguments passed to inference_on_geotiff. | {}

Returns:

Type | Description
---|---
None | Output mask is saved to output_path.
Source code in geoai/train.py
def object_detection(
input_path,
output_path,
model_path,
window_size=512,
overlap=256,
confidence_threshold=0.5,
batch_size=4,
num_channels=3,
pretrained=True,
device=None,
**kwargs,
):
"""
Perform object detection on a GeoTIFF using a pre-trained Mask R-CNN model.
Args:
input_path (str): Path to input GeoTIFF file.
output_path (str): Path to save output mask GeoTIFF.
model_path (str): Path to trained model weights.
window_size (int): Size of sliding window for inference.
overlap (int): Overlap between adjacent windows.
confidence_threshold (float): Confidence threshold for predictions (0-1).
batch_size (int): Batch size for inference.
num_channels (int): Number of channels in the input image and model.
pretrained (bool): Whether to use pretrained backbone for model loading.
device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
**kwargs: Additional arguments passed to inference_on_geotiff.
Returns:
None: Output mask is saved to output_path.
"""
# Load your trained model
if device is None:
device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
model = get_instance_segmentation_model(
num_classes=2, num_channels=num_channels, pretrained=pretrained
)
if not os.path.exists(model_path):
try:
model_path = download_model_from_hf(model_path)
        except Exception as e:
            raise FileNotFoundError(f"Model file not found: {model_path}") from e
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
inference_on_geotiff(
model=model,
geotiff_path=input_path,
output_path=output_path,
window_size=window_size, # Adjust based on your model and memory
overlap=overlap, # Overlap to avoid edge artifacts
confidence_threshold=confidence_threshold,
batch_size=batch_size, # Adjust based on your GPU memory
num_channels=num_channels,
device=device,
**kwargs,
)
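A one-call sketch; the paths are hypothetical. If model_path does not exist locally, the function tries to resolve it via download_model_from_hf:

```python
from geoai.train import object_detection

object_detection(
    input_path="data/naip_scene.tif",        # hypothetical input
    output_path="output/buildings_mask.tif",
    model_path="output/best_model.pth",      # local file or Hugging Face model ID
    window_size=512,
    overlap=256,
    confidence_threshold=0.5,
)
```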
object_detection_batch(input_paths, output_dir, model_path, filenames=None, window_size=512, overlap=256, confidence_threshold=0.5, batch_size=4, num_channels=3, pretrained=True, device=None, **kwargs)¶
Perform object detection on one or more GeoTIFFs using a pre-trained Mask R-CNN model.
Parameters:

Name | Type | Description | Default
---|---|---|---
input_paths | str or list | Path(s) to input GeoTIFF file(s). If a directory is provided, all .tif files in that directory will be processed. | required
output_dir | str | Directory to save output mask GeoTIFF files. | required
model_path | str | Path to trained model weights. | required
filenames | list | List of output filenames. If None, defaults to "<input_filename>_mask.tif" for each input file. If provided, must match the number of input files. | None
window_size | int | Size of sliding window for inference. | 512
overlap | int | Overlap between adjacent windows. | 256
confidence_threshold | float | Confidence threshold for predictions (0-1). | 0.5
batch_size | int | Batch size for inference. | 4
num_channels | int | Number of channels in the input image and model. | 3
pretrained | bool | Whether to use pretrained backbone for model loading. | True
device | torch.device | Device to run inference on. If None, uses CUDA if available. | None
**kwargs | | Additional arguments passed to inference_on_geotiff. | {}

Returns:

Type | Description
---|---
None | Output masks are saved to output_dir.
Source code in geoai/train.py
def object_detection_batch(
input_paths,
output_dir,
model_path,
filenames=None,
window_size=512,
overlap=256,
confidence_threshold=0.5,
batch_size=4,
num_channels=3,
pretrained=True,
device=None,
**kwargs,
):
"""
    Perform object detection on one or more GeoTIFFs using a pre-trained Mask R-CNN model.
Args:
input_paths (str or list): Path(s) to input GeoTIFF file(s). If a directory is provided,
all .tif files in that directory will be processed.
output_dir (str): Directory to save output mask GeoTIFF files.
model_path (str): Path to trained model weights.
filenames (list, optional): List of output filenames. If None, defaults to
"<input_filename>_mask.tif" for each input file.
If provided, must match the number of input files.
window_size (int): Size of sliding window for inference.
overlap (int): Overlap between adjacent windows.
confidence_threshold (float): Confidence threshold for predictions (0-1).
batch_size (int): Batch size for inference.
num_channels (int): Number of channels in the input image and model.
pretrained (bool): Whether to use pretrained backbone for model loading.
device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
**kwargs: Additional arguments passed to inference_on_geotiff.
Returns:
        None: Output masks are saved to output_dir.
"""
# Load your trained model
if device is None:
device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
model = get_instance_segmentation_model(
num_classes=2, num_channels=num_channels, pretrained=pretrained
)
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
if not os.path.exists(model_path):
try:
model_path = download_model_from_hf(model_path)
        except Exception as e:
            raise FileNotFoundError(f"Model file not found: {model_path}") from e
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
    if isinstance(input_paths, str) and (not input_paths.endswith(".tif")):
        files = glob.glob(os.path.join(input_paths, "*.tif"))
        files.sort()
    elif isinstance(input_paths, str):
        files = [input_paths]
    else:
        files = list(input_paths)  # assume an iterable of GeoTIFF paths
if filenames is None:
filenames = [
os.path.join(output_dir, os.path.basename(f).replace(".tif", "_mask.tif"))
for f in files
]
else:
if len(filenames) != len(files):
raise ValueError("Number of filenames must match number of input files.")
for index, file in enumerate(files):
print(f"Processing file {index + 1}/{len(files)}: {file}")
inference_on_geotiff(
model=model,
geotiff_path=file,
output_path=filenames[index],
window_size=window_size, # Adjust based on your model and memory
overlap=overlap, # Overlap to avoid edge artifacts
confidence_threshold=confidence_threshold,
batch_size=batch_size, # Adjust based on your GPU memory
num_channels=num_channels,
device=device,
**kwargs,
)
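A batch sketch over a directory of tiles (hypothetical paths); output names default to "<input_filename>_mask.tif" inside output_dir:

```python
from geoai.train import object_detection_batch

object_detection_batch(
    input_paths="data/tiles",            # directory of GeoTIFF tiles
    output_dir="output/masks",
    model_path="output/best_model.pth",  # local file or Hugging Face model ID
    batch_size=4,
)
```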
train_MaskRCNN_model(images_dir, labels_dir, output_dir, num_channels=3, pretrained=True, pretrained_model_path=None, batch_size=4, num_epochs=10, learning_rate=0.005, seed=42, val_split=0.2, visualize=False, resume_training=False, print_freq=10, verbose=True)¶
Train and evaluate Mask R-CNN model for instance segmentation.
This function trains a Mask R-CNN model for instance segmentation using the provided dataset. It supports loading a pretrained model to either initialize the backbone or to continue training from a specific checkpoint.
Parameters:

Name | Type | Description | Default
---|---|---|---
images_dir | str | Directory containing image GeoTIFF files. | required
labels_dir | str | Directory containing label GeoTIFF files. | required
output_dir | str | Directory to save model checkpoints and results. | required
num_channels | int | Number of input channels. If None, auto-detected. Defaults to 3. | 3
pretrained | bool | Whether to use pretrained backbone. This is ignored if pretrained_model_path is provided. Defaults to True. | True
pretrained_model_path | str | Path to a .pth file to load as a pretrained model for continued training. Defaults to None. | None
batch_size | int | Batch size for training. Defaults to 4. | 4
num_epochs | int | Number of training epochs. Defaults to 10. | 10
learning_rate | float | Initial learning rate. Defaults to 0.005. | 0.005
seed | int | Random seed for reproducibility. Defaults to 42. | 42
val_split | float | Fraction of data to use for validation (0-1). Defaults to 0.2. | 0.2
visualize | bool | Whether to generate visualizations of model predictions. Defaults to False. | False
resume_training | bool | If True and pretrained_model_path is provided, will try to load optimizer and scheduler states as well. Defaults to False. | False
print_freq | int | Frequency of printing training progress. Defaults to 10. | 10
verbose | bool | If True, prints detailed training progress. Defaults to True. | True

Returns:

Type | Description
---|---
None | Model weights are saved to output_dir.

Exceptions:

Type | Description
---|---
FileNotFoundError | If pretrained_model_path is provided but file doesn't exist.
RuntimeError | If there's an issue loading the pretrained model.
Source code in geoai/train.py
def train_MaskRCNN_model(
images_dir,
labels_dir,
output_dir,
num_channels=3,
pretrained=True,
pretrained_model_path=None,
batch_size=4,
num_epochs=10,
learning_rate=0.005,
seed=42,
val_split=0.2,
visualize=False,
resume_training=False,
print_freq=10,
verbose=True,
):
"""Train and evaluate Mask R-CNN model for instance segmentation.
This function trains a Mask R-CNN model for instance segmentation using the
provided dataset. It supports loading a pretrained model to either initialize
the backbone or to continue training from a specific checkpoint.
Args:
images_dir (str): Directory containing image GeoTIFF files.
labels_dir (str): Directory containing label GeoTIFF files.
output_dir (str): Directory to save model checkpoints and results.
num_channels (int, optional): Number of input channels. If None, auto-detected.
Defaults to 3.
pretrained (bool): Whether to use pretrained backbone. This is ignored if
pretrained_model_path is provided. Defaults to True.
pretrained_model_path (str, optional): Path to a .pth file to load as a
pretrained model for continued training. Defaults to None.
batch_size (int): Batch size for training. Defaults to 4.
num_epochs (int): Number of training epochs. Defaults to 10.
learning_rate (float): Initial learning rate. Defaults to 0.005.
seed (int): Random seed for reproducibility. Defaults to 42.
val_split (float): Fraction of data to use for validation (0-1). Defaults to 0.2.
visualize (bool): Whether to generate visualizations of model predictions.
Defaults to False.
resume_training (bool): If True and pretrained_model_path is provided,
will try to load optimizer and scheduler states as well. Defaults to False.
print_freq (int): Frequency of printing training progress. Defaults to 10.
verbose (bool): If True, prints detailed training progress. Defaults to True.
Returns:
None: Model weights are saved to output_dir.
Raises:
FileNotFoundError: If pretrained_model_path is provided but file doesn't exist.
RuntimeError: If there's an issue loading the pretrained model.
"""
import datetime
# Set random seeds for reproducibility
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Create output directory
os.makedirs(output_dir, exist_ok=True)
# Get device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Get all image and label files
image_files = sorted(
[
os.path.join(images_dir, f)
for f in os.listdir(images_dir)
if f.endswith(".tif")
]
)
label_files = sorted(
[
os.path.join(labels_dir, f)
for f in os.listdir(labels_dir)
if f.endswith(".tif")
]
)
print(f"Found {len(image_files)} image files and {len(label_files)} label files")
# Ensure matching files
if len(image_files) != len(label_files):
print("Warning: Number of image files and label files don't match!")
# Find matching files by basename
basenames = [os.path.basename(f) for f in image_files]
label_files = [
os.path.join(labels_dir, os.path.basename(f))
for f in image_files
if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
]
image_files = [
f
for f, b in zip(image_files, basenames)
if os.path.exists(os.path.join(labels_dir, b))
]
print(f"Using {len(image_files)} matching files")
# Split data into train and validation sets
train_imgs, val_imgs, train_labels, val_labels = train_test_split(
image_files, label_files, test_size=val_split, random_state=seed
)
print(f"Training on {len(train_imgs)} images, validating on {len(val_imgs)} images")
# Create datasets
train_dataset = ObjectDetectionDataset(
train_imgs, train_labels, transforms=get_transform(train=True)
)
val_dataset = ObjectDetectionDataset(
val_imgs, val_labels, transforms=get_transform(train=False)
)
# Create data loaders
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=4,
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=collate_fn,
num_workers=4,
)
# Initialize model (2 classes: background and building)
model = get_instance_segmentation_model(
num_classes=2, num_channels=num_channels, pretrained=pretrained
)
model.to(device)
# Set up optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
params, lr=learning_rate, momentum=0.9, weight_decay=0.0005
)
# Set up learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)
# Initialize training variables
start_epoch = 0
best_iou = 0
# Load pretrained model if provided
if pretrained_model_path:
if not os.path.exists(pretrained_model_path):
raise FileNotFoundError(
f"Pretrained model file not found: {pretrained_model_path}"
)
print(f"Loading pretrained model from: {pretrained_model_path}")
try:
# Check if it's a full checkpoint or just model weights
checkpoint = torch.load(pretrained_model_path, map_location=device)
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
# It's a checkpoint with extra information
model.load_state_dict(checkpoint["model_state_dict"])
if resume_training:
# Resume from checkpoint
start_epoch = checkpoint.get("epoch", 0) + 1
best_iou = checkpoint.get("best_iou", 0)
if "optimizer_state_dict" in checkpoint:
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if "scheduler_state_dict" in checkpoint:
lr_scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
print(f"Resuming training from epoch {start_epoch}")
print(f"Previous best IoU: {best_iou:.4f}")
else:
# Assume it's just the model weights
model.load_state_dict(checkpoint)
print("Pretrained model loaded successfully")
except Exception as e:
raise RuntimeError(f"Failed to load pretrained model: {str(e)}")
# Training loop
for epoch in range(start_epoch, num_epochs):
# Train one epoch
train_loss = train_one_epoch(
model, optimizer, train_loader, device, epoch, print_freq, verbose
)
# Update learning rate
lr_scheduler.step()
# Evaluate
eval_metrics = evaluate(model, val_loader, device)
# Print metrics
print(
f"Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, Val Loss: {eval_metrics['loss']:.4f}, Val IoU: {eval_metrics['IoU']:.4f}"
)
# Save best model
if eval_metrics["IoU"] > best_iou:
best_iou = eval_metrics["IoU"]
print(f"Saving best model with IoU: {best_iou:.4f}")
torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
# Save checkpoint every 10 epochs
if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": lr_scheduler.state_dict(),
"best_iou": best_iou,
},
os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"),
)
# Save final model
torch.save(model.state_dict(), os.path.join(output_dir, "final_model.pth"))
# Save full checkpoint of final state
torch.save(
{
"epoch": num_epochs - 1,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": lr_scheduler.state_dict(),
"best_iou": best_iou,
},
os.path.join(output_dir, "final_checkpoint.pth"),
)
# Load best model for evaluation and visualization
model.load_state_dict(torch.load(os.path.join(output_dir, "best_model.pth")))
# Final evaluation
final_metrics = evaluate(model, val_loader, device)
print(
f"Final Evaluation - Loss: {final_metrics['loss']:.4f}, IoU: {final_metrics['IoU']:.4f}"
)
# Visualize results
if visualize:
print("Generating visualizations...")
visualize_predictions(
model,
val_dataset,
device,
num_samples=5,
output_dir=os.path.join(output_dir, "visualizations"),
)
# Save training summary
with open(os.path.join(output_dir, "training_summary.txt"), "w") as f:
f.write(
f"Training completed on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
)
f.write(f"Total epochs: {num_epochs}\n")
f.write(f"Best validation IoU: {best_iou:.4f}\n")
f.write(f"Final validation IoU: {final_metrics['IoU']:.4f}\n")
f.write(f"Final validation loss: {final_metrics['loss']:.4f}\n")
if pretrained_model_path:
f.write(f"Started from pretrained model: {pretrained_model_path}\n")
if resume_training:
f.write(f"Resumed training from epoch {start_epoch}\n")
print(f"Training complete! Trained model saved to {output_dir}")
train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10, verbose=True)¶
Train the model for one epoch.
Parameters:

Name | Type | Description | Default
---|---|---|---
model | torch.nn.Module | The model to train. | required
optimizer | torch.optim.Optimizer | The optimizer to use. | required
data_loader | torch.utils.data.DataLoader | DataLoader for training data. | required
device | torch.device | Device to train on. | required
epoch | int | Current epoch number. | required
print_freq | int | How often to print progress. | 10
verbose | bool | Whether to print detailed progress. | True

Returns:

Type | Description
---|---
float | Average loss for the epoch.
Source code in geoai/train.py
def train_one_epoch(
model, optimizer, data_loader, device, epoch, print_freq=10, verbose=True
):
"""
Train the model for one epoch.
Args:
model (torch.nn.Module): The model to train.
optimizer (torch.optim.Optimizer): The optimizer to use.
data_loader (torch.utils.data.DataLoader): DataLoader for training data.
device (torch.device): Device to train on.
epoch (int): Current epoch number.
print_freq (int): How often to print progress.
verbose (bool): Whether to print detailed progress.
Returns:
float: Average loss for the epoch.
"""
model.train()
total_loss = 0
start_time = time.time()
for i, (images, targets) in enumerate(data_loader):
# Move images and targets to device
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
# Forward pass
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
# Backward pass
optimizer.zero_grad()
losses.backward()
optimizer.step()
# Track loss
total_loss += losses.item()
# Print progress
if i % print_freq == 0:
elapsed_time = time.time() - start_time
if verbose:
print(
f"Epoch: {epoch}, Batch: {i}/{len(data_loader)}, Loss: {losses.item():.4f}, Time: {elapsed_time:.2f}s"
)
start_time = time.time()
# Calculate average loss
avg_loss = total_loss / len(data_loader)
return avg_loss
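For a hand-rolled loop outside train_MaskRCNN_model, the helper can be driven directly. A minimal sketch, assuming a single hypothetical training tile:

```python
import torch
from torch.utils.data import DataLoader

from geoai.train import (ObjectDetectionDataset, collate_fn,
                         get_instance_segmentation_model, get_transform,
                         train_one_epoch)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_instance_segmentation_model(num_classes=2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

dataset = ObjectDetectionDataset(
    ["data/images/tile_1.tif"],  # hypothetical paths
    ["data/labels/tile_1.tif"],
    transforms=get_transform(train=True),
)
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)

for epoch in range(2):
    avg_loss = train_one_epoch(model, optimizer, loader, device, epoch)
    print(f"epoch {epoch}: avg loss {avg_loss:.4f}")
```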
visualize_predictions(model, dataset, device, num_samples=5, output_dir=None)¶
Visualize model predictions.
Parameters:

Name | Type | Description | Default
---|---|---|---
model | torch.nn.Module | Trained model. | required
dataset | torch.utils.data.Dataset | Dataset to visualize. | required
device | torch.device | Device to run inference on. | required
num_samples | int | Number of samples to visualize. | 5
output_dir | str | Directory to save visualizations. If None, visualizations are displayed but not saved. | None
Source code in geoai/train.py
def visualize_predictions(model, dataset, device, num_samples=5, output_dir=None):
"""
Visualize model predictions.
Args:
model (torch.nn.Module): Trained model.
dataset (torch.utils.data.Dataset): Dataset to visualize.
device (torch.device): Device to run inference on.
num_samples (int): Number of samples to visualize.
output_dir (str, optional): Directory to save visualizations. If None,
visualizations are displayed but not saved.
"""
model.eval()
# Create output directory if needed
if output_dir:
os.makedirs(output_dir, exist_ok=True)
# Select random samples
indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))
for idx in indices:
# Get image and target
image, target = dataset[idx]
# Convert to device and add batch dimension
image = image.to(device)
image_batch = [image]
# Get prediction
with torch.no_grad():
output = model(image_batch)[0]
# Convert image from CHW to HWC for display (first 3 bands as RGB)
rgb_image = image[:3].cpu().numpy()
rgb_image = np.transpose(rgb_image, (1, 2, 0))
rgb_image = np.clip(rgb_image, 0, 1) # Ensure values are in [0,1]
# Create binary ground truth mask (combine all instances)
gt_masks = target["masks"].cpu().numpy()
gt_combined = (
np.max(gt_masks, axis=0)
if len(gt_masks) > 0
else np.zeros((image.shape[1], image.shape[2]), dtype=np.uint8)
)
# Create binary prediction mask (combine all instances with score > 0.5)
pred_masks = output["masks"].cpu().numpy()
pred_scores = output["scores"].cpu().numpy()
high_conf_indices = pred_scores > 0.5
pred_combined = np.zeros((image.shape[1], image.shape[2]), dtype=np.float32)
if np.any(high_conf_indices):
for mask in pred_masks[high_conf_indices]:
# Apply threshold to each predicted mask
binary_mask = (mask[0] > 0.5).astype(np.float32)
# Combine with existing masks
pred_combined = np.maximum(pred_combined, binary_mask)
# Create figure
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
# Show RGB image
axs[0].imshow(rgb_image)
axs[0].set_title("RGB Image")
axs[0].axis("off")
# Show prediction
axs[1].imshow(pred_combined, cmap="viridis")
axs[1].set_title(f"Predicted Buildings: {np.sum(high_conf_indices)} instances")
axs[1].axis("off")
# Show ground truth
axs[2].imshow(gt_combined, cmap="viridis")
axs[2].set_title(f"Ground Truth: {len(gt_masks)} instances")
axs[2].axis("off")
plt.tight_layout()
# Save or show
if output_dir:
plt.savefig(os.path.join(output_dir, f"prediction_{idx}.png"))
plt.close()
else:
plt.show()
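A closing sketch, assuming a trained model and a validation dataset built as in the training examples above:

```python
import torch

from geoai.train import visualize_predictions

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
visualize_predictions(
    model,        # trained Mask R-CNN from the sketches above
    val_dataset,  # an ObjectDetectionDataset
    device,
    num_samples=5,
    output_dir="output/visualizations",  # hypothetical path
)
```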