classify module

A module for training semantic segmentation models and classifying remote sensing imagery with them.

classify_image(image_path, model_path, output_path=None, chip_size=1024, overlap=256, batch_size=4, colormap=None, **kwargs)

Classify a geospatial image using a trained semantic segmentation model.

This function handles the full image classification pipeline, with special attention to edge handling:

1. Process the image in a grid of overlapping tiles.
2. Use the central regions of tiles for the interior of the image.
3. Handle edge tiles specially to ensure complete coverage.
4. Merge the results into a single georeferenced output.

Parameters:

    image_path (str): Path to the input GeoTIFF image. Required.
    model_path (str): Path to the trained model checkpoint. Required.
    output_path (str, optional): Path to save the output classified image.
        Defaults to "[input_name]_classified.tif".
    chip_size (int, optional): Size of chips for processing. Defaults to 1024.
    overlap (int, optional): Overlap size between adjacent tiles. Defaults to 256.
    batch_size (int, optional): Batch size for inference. Defaults to 4.
    colormap (dict, optional): Colormap to apply to the output image. Defaults to None.
    **kwargs: Additional keyword arguments for DataLoader.

Returns:

    str: Path to the saved classified image.
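
A minimal usage sketch (paths are hypothetical). Note that the source below moves the model and batches to CUDA, so a GPU is required:

result = classify_image(
    image_path="naip_scene.tif",
    model_path="output/models/best_model.ckpt",
    chip_size=1024,
    overlap=256,
    batch_size=4,
)
print(result)  # naip_scene_classified.tif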

Source code in geoai/classify.py
def classify_image(
    image_path,
    model_path,
    output_path=None,
    chip_size=1024,
    overlap=256,
    batch_size=4,
    colormap=None,
    **kwargs,
):
    """
    Classify a geospatial image using a trained semantic segmentation model.

    This function handles the full image classification pipeline with special
    attention to edge handling:
    1. Process the image in a grid pattern with overlapping tiles
    2. Use central regions of tiles for interior parts
    3. Special handling for edges to ensure complete coverage
    4. Merge results into a single georeferenced output

    Parameters:
        image_path (str): Path to the input GeoTIFF image.
        model_path (str): Path to the trained model checkpoint.
        output_path (str, optional): Path to save the output classified image.
                                    Defaults to "[input_name]_classified.tif".
        chip_size (int, optional): Size of chips for processing. Defaults to 1024.
        overlap (int, optional): Overlap size between adjacent tiles. Defaults to 256.
        batch_size (int, optional): Batch size for inference. Defaults to 4.
        colormap (dict, optional): Colormap to apply to the output image.
                                   Defaults to None.
        **kwargs: Additional keyword arguments for DataLoader.

    Returns:
        str: Path to the saved classified image.
    """
    import os  # module-level import in geoai/classify.py, repeated here so the excerpt is self-contained
    import timeit
    import warnings

    import numpy as np  # likewise imported at module level in the original
    import rasterio
    import torch
    from rasterio.errors import NotGeoreferencedWarning
    from torchgeo.trainers import SemanticSegmentationTask

    # Disable specific GDAL/rasterio warnings
    warnings.filterwarnings("ignore", category=UserWarning, module="rasterio._.*")
    warnings.filterwarnings("ignore", category=UserWarning, module="rasterio")
    warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

    # Also suppress GDAL error reports
    import logging

    logging.getLogger("rasterio").setLevel(logging.ERROR)

    # Set default output path if not provided
    if output_path is None:
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        output_path = f"{base_name}_classified.tif"

    # Make sure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Load the model
    print(f"Loading model from {model_path}...")
    task = SemanticSegmentationTask.load_from_checkpoint(model_path)
    task.model.eval()
    task.model.cuda()

    # Process the image using a modified tiling approach
    with rasterio.open(image_path) as src:
        # Get image dimensions and metadata
        height = src.height
        width = src.width
        profile = src.profile.copy()

        # Prepare output array for the final result
        output_image = np.zeros((height, width), dtype=np.uint8)
        confidence_map = np.zeros((height, width), dtype=np.float32)

        # Calculate number of tiles needed with overlap
        # Ensure we have tiles that specifically cover the edges
        effective_stride = chip_size - overlap

        # Calculate x positions ensuring leftmost and rightmost edges are covered
        x_positions = []
        # Always include the leftmost position
        x_positions.append(0)
        # Add regular grid positions
        for x in range(effective_stride, width - chip_size, effective_stride):
            x_positions.append(x)
        # Always include rightmost position that still fits
        if width > chip_size and x_positions[-1] + chip_size < width:
            x_positions.append(width - chip_size)

        # Calculate y positions ensuring top and bottom edges are covered
        y_positions = []
        # Always include the topmost position
        y_positions.append(0)
        # Add regular grid positions
        for y in range(effective_stride, height - chip_size, effective_stride):
            y_positions.append(y)
        # Always include bottommost position that still fits
        if height > chip_size and y_positions[-1] + chip_size < height:
            y_positions.append(height - chip_size)

        # Create list of all tile positions
        tile_positions = []
        for y in y_positions:
            for x in x_positions:
                y_end = min(y + chip_size, height)
                x_end = min(x + chip_size, width)
                tile_positions.append((y, x, y_end, x_end))

        # Print information about the tiling
        print(
            f"Processing {len(tile_positions)} patches covering an image of size {height}x{width}..."
        )
        start_time = timeit.default_timer()

        # Process tiles in batches
        for batch_start in range(0, len(tile_positions), batch_size):
            batch_end = min(batch_start + batch_size, len(tile_positions))
            batch_positions = tile_positions[batch_start:batch_end]
            batch_data = []

            # Load data for current batch
            for y_start, x_start, y_end, x_end in batch_positions:
                # Calculate actual tile size
                actual_height = y_end - y_start
                actual_width = x_end - x_start

                # Read the tile data
                tile_data = src.read(window=((y_start, y_end), (x_start, x_end)))

                # Handle different sized tiles by padding if necessary
                if tile_data.shape[1] != chip_size or tile_data.shape[2] != chip_size:
                    padded_data = np.zeros(
                        (tile_data.shape[0], chip_size, chip_size),
                        dtype=tile_data.dtype,
                    )
                    padded_data[:, : tile_data.shape[1], : tile_data.shape[2]] = (
                        tile_data
                    )
                    tile_data = padded_data

                # Convert to tensor
                tile_tensor = torch.from_numpy(tile_data).float() / 255.0
                batch_data.append(tile_tensor)

            # Convert batch to tensor
            batch_tensor = torch.stack(batch_data)

            # Run inference
            with torch.no_grad():
                logits = task.model.predict(batch_tensor.cuda())
                probs = torch.softmax(logits, dim=1)
                confidence, predictions = torch.max(probs, dim=1)
                predictions = predictions.cpu().numpy()
                confidence = confidence.cpu().numpy()

            # Process each prediction
            for idx, (y_start, x_start, y_end, x_end) in enumerate(batch_positions):
                pred = predictions[idx]
                conf = confidence[idx]

                # Calculate actual tile size
                actual_height = y_end - y_start
                actual_width = x_end - x_start

                # Get the actual prediction (removing padding if needed)
                valid_pred = pred[:actual_height, :actual_width]
                valid_conf = conf[:actual_height, :actual_width]

                # Create confidence weights that favor central parts of tiles
                # but still allow edge tiles to contribute fully at the image edges
                is_edge_x = (x_start == 0) or (x_end == width)
                is_edge_y = (y_start == 0) or (y_end == height)

                # Create a mask that gives higher weight to central regions
                # but ensures proper edge handling for boundary tiles
                weight_mask = np.ones((actual_height, actual_width), dtype=np.float32)

                # Only apply central weighting if not at an image edge
                border = overlap // 2
                if not is_edge_x and actual_width > 2 * border:
                    # Apply horizontal edge falloff (linear)
                    for i in range(border):
                        # Left edge
                        weight_mask[:, i] = (i + 1) / (border + 1)
                        # Right edge (if not at image edge)
                        if i < actual_width - border:
                            weight_mask[:, actual_width - i - 1] = (i + 1) / (
                                border + 1
                            )

                if not is_edge_y and actual_height > 2 * border:
                    # Apply vertical edge falloff (linear)
                    for i in range(border):
                        # Top edge
                        weight_mask[i, :] = (i + 1) / (border + 1)
                        # Bottom edge (if not at image edge)
                        if i < actual_height - border:
                            weight_mask[actual_height - i - 1, :] = (i + 1) / (
                                border + 1
                            )

                # Combine with prediction confidence
                final_weight = weight_mask * valid_conf

                # Update the output image based on confidence
                current_conf = confidence_map[y_start:y_end, x_start:x_end]
                update_mask = final_weight > current_conf

                if np.any(update_mask):
                    # Update only pixels where this prediction has higher confidence
                    output_image[y_start:y_end, x_start:x_end][update_mask] = (
                        valid_pred[update_mask]
                    )
                    confidence_map[y_start:y_end, x_start:x_end][update_mask] = (
                        final_weight[update_mask]
                    )

        # Update profile for output
        profile.update({"count": 1, "dtype": "uint8", "nodata": 0})

        # Save the result
        print(f"Saving classified image to {output_path}...")
        with rasterio.open(output_path, "w", **profile) as dst:
            dst.write(output_image[np.newaxis, :, :])
            if isinstance(colormap, dict):
                dst.write_colormap(1, colormap)

        # Calculate timing
        total_time = timeit.default_timer() - start_time
        print(f"Total processing time: {total_time:.2f} seconds")
        print(f"Successfully saved classified image to {output_path}")

    return output_path
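
The merge step above keeps, at each pixel, the prediction whose combined weight (linear edge falloff times softmax confidence) is highest. Below is a simplified numpy sketch of such a falloff mask for an interior tile; the function name and sizes are illustrative, and the source combines the row and column ramps slightly differently at corners and skips the falloff at image edges:

import numpy as np

def linear_edge_weights(height, width, border):
    """Weight mask that is 1.0 in the interior and ramps down linearly
    over the outer `border` rows/columns, so overlapping neighbours win
    near tile seams."""
    weights = np.ones((height, width), dtype=np.float32)
    ramp = (np.arange(border) + 1) / (border + 1)  # e.g. [1/3, 2/3] for border=2
    weights[:, :border] = np.minimum(weights[:, :border], ramp[np.newaxis, :])
    weights[:, width - border:] = np.minimum(weights[:, width - border:], ramp[::-1][np.newaxis, :])
    weights[:border, :] = np.minimum(weights[:border, :], ramp[:, np.newaxis])
    weights[height - border:, :] = np.minimum(weights[height - border:, :], ramp[::-1][:, np.newaxis])
    return weights

print(linear_edge_weights(8, 8, 2)[4])
# approx. [0.33, 0.67, 1., 1., 1., 1., 0.67, 0.33]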

classify_images(image_paths, model_path, output_dir=None, chip_size=1024, batch_size=4, colormap=None, file_extension='.tif', **kwargs)

Classify multiple geospatial images using a trained semantic segmentation model.

This function accepts either a list of image paths or a directory containing images and applies the classify_image function to each image, saving the results in the specified output directory.

Parameters:

    image_paths (str or list): Either a directory path containing images or a list
        of paths to input GeoTIFF images. Required.
    model_path (str): Path to the trained model checkpoint. Required.
    output_dir (str, optional): Directory to save the output classified images.
        Defaults to None (the input images' directory for a list input, or a new
        "classified" subdirectory for a directory input).
    chip_size (int, optional): Size of chips for processing. Defaults to 1024.
    batch_size (int, optional): Batch size for inference. Defaults to 4.
    colormap (dict, optional): Colormap to apply to the output images. Defaults to None.
    file_extension (str, optional): File extension to filter by when image_paths is
        a directory. Defaults to ".tif".
    **kwargs: Additional keyword arguments for the classify_image function.

Returns:

    list: List of paths to the saved classified images.
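
A minimal usage sketch (paths are hypothetical); with a directory input, outputs land in a "classified" subdirectory by default:

paths = classify_images(
    "tiles/",
    "output/models/best_model.ckpt",
    batch_size=4,
)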

Source code in geoai/classify.py
def classify_images(
    image_paths,
    model_path,
    output_dir=None,
    chip_size=1024,
    batch_size=4,
    colormap=None,
    file_extension=".tif",
    **kwargs,
):
    """
    Classify multiple geospatial images using a trained semantic segmentation model.

    This function accepts either a list of image paths or a directory containing images
    and applies the classify_image function to each image, saving the results in the
    specified output directory.

    Parameters:
        image_paths (str or list): Either a directory path containing images or a list
            of paths to input GeoTIFF images.
        model_path (str): Path to the trained model checkpoint.
        output_dir (str, optional): Directory to save the output classified images.
            Defaults to None (same directory as input images for a list, or a new
            "classified" subdirectory for a directory input).
        chip_size (int, optional): Size of chips for processing. Defaults to 1024.
        batch_size (int, optional): Batch size for inference. Defaults to 4.
        colormap (dict, optional): Colormap to apply to the output images.
            Defaults to None.
        file_extension (str, optional): File extension to filter by when image_paths
            is a directory. Defaults to ".tif".
        **kwargs: Additional keyword arguments for the classify_image function.

    Returns:
        list: List of paths to the saved classified images.
    """
    # Import required libraries
    import glob

    from tqdm import tqdm

    # Process directory input
    if isinstance(image_paths, str) and os.path.isdir(image_paths):
        # Set default output directory if not provided
        if output_dir is None:
            output_dir = os.path.join(image_paths, "classified")

        # Get all images with the specified extension
        image_path_list = glob.glob(os.path.join(image_paths, f"*{file_extension}"))

        # Check if any images were found
        if not image_path_list:
            print(f"No files with extension '{file_extension}' found in {image_paths}")
            return []

        print(f"Found {len(image_path_list)} images in directory {image_paths}")

    # Process list input
    elif isinstance(image_paths, list):
        image_path_list = image_paths

        # Set default output directory if not provided
        if output_dir is None and len(image_path_list) > 0:
            output_dir = os.path.dirname(image_path_list[0])

    # Invalid input
    else:
        raise ValueError(
            "image_paths must be either a directory path or a list of file paths"
        )

    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    classified_image_paths = []

    # Create progress bar
    for image_path in tqdm(image_path_list, desc="Classifying images", unit="image"):
        try:
            # Get just the filename without extension
            base_filename = os.path.splitext(os.path.basename(image_path))[0]

            # Create output path within output_dir
            output_path = os.path.join(
                output_dir, f"{base_filename}_classified{file_extension}"
            )

            # Perform classification
            classified_image_path = classify_image(
                image_path,
                model_path,
                output_path=output_path,
                chip_size=chip_size,
                batch_size=batch_size,
                colormap=colormap,
                **kwargs,
            )
            classified_image_paths.append(classified_image_path)
        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")

    print(
        f"Classification complete. Processed {len(classified_image_paths)} images successfully."
    )
    return classified_image_paths
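
The colormap argument is handed through to rasterio's write_colormap, which expects a dict mapping class indices to RGB or RGBA tuples. A sketch with a hypothetical land-cover palette:

# Hypothetical palette: class index -> (R, G, B)
colormap = {
    0: (0, 0, 0),        # nodata / background
    1: (34, 139, 34),    # vegetation
    2: (0, 105, 148),    # water
    3: (128, 128, 128),  # impervious surface
}

paths = classify_images("tiles/", "output/models/best_model.ckpt", colormap=colormap)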

train_classifier(image_root, label_root, output_dir='output', in_channels=4, num_classes=14, epochs=20, img_size=256, batch_size=8, sample_size=500, model='unet', backbone='resnet50', weights=True, num_filters=3, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False, transforms=None, use_augmentation=False, seed=42, train_val_test_split=(0.6, 0.2, 0.2), accelerator='auto', devices='auto', logger=None, callbacks=None, log_every_n_steps=10, use_distributed_sampler=False, monitor_metric='val_loss', mode='min', save_top_k=1, save_last=True, checkpoint_filename='best_model', checkpoint_path=None, every_n_epochs=1, **kwargs)

Train a semantic segmentation model on geospatial imagery.

This function sets up datasets, model, trainer, and executes the training process for semantic segmentation tasks using geospatial data. It supports training from scratch or resuming from a checkpoint if available.

Parameters:

    image_root (str): Path to directory containing imagery. Required.
    label_root (str): Path to directory containing land cover labels. Required.
    output_dir (str, optional): Directory to save model outputs and checkpoints.
        Defaults to "output".
    in_channels (int, optional): Number of input channels in the imagery. Defaults to 4.
    num_classes (int, optional): Number of classes in the segmentation task. Defaults to 14.
    epochs (int, optional): Number of training epochs. Defaults to 20.
    img_size (int, optional): Size of image patches for training. Defaults to 256.
    batch_size (int, optional): Batch size for training. Defaults to 8.
    sample_size (int, optional): Number of samples per epoch. Defaults to 500.
    model (str, optional): Model architecture to use. Defaults to "unet".
    backbone (str, optional): Backbone network for the model. Defaults to "resnet50".
    weights (bool, optional): Whether to use pretrained weights. Defaults to True.
    num_filters (int, optional): Number of filters for the model. Defaults to 3.
    loss (str, optional): Loss function to use ('ce', 'jaccard', or 'focal'). Defaults to "ce".
    class_weights (list, optional): Class weights for the loss function. Defaults to None.
    ignore_index (int, optional): Index to ignore in loss calculation. Defaults to None.
    lr (float, optional): Learning rate. Defaults to 0.001.
    patience (int, optional): Number of epochs with no improvement after which training
        stops. Defaults to 10.
    freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
    freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
    transforms (callable, optional): Transforms to apply to the data. Defaults to None.
    use_augmentation (bool, optional): Whether to apply data augmentation. Defaults to False.
    seed (int, optional): Random seed for reproducibility. Defaults to 42.
    train_val_test_split (tuple, optional): Proportions for the train/val/test split.
        Defaults to (0.6, 0.2, 0.2).
    accelerator (str, optional): Accelerator to use for training ('cpu', 'gpu', etc.).
        Defaults to "auto".
    devices (str, optional): Number of devices to use for training. Defaults to "auto".
    logger (object, optional): Logger for tracking training progress. Defaults to None.
    callbacks (list, optional): List of callbacks for the trainer. Defaults to None.
    log_every_n_steps (int, optional): Frequency of logging training progress. Defaults to 10.
    use_distributed_sampler (bool, optional): Whether to use distributed sampling.
        Defaults to False.
    monitor_metric (str, optional): Metric to monitor for saving the best model.
        Defaults to "val_loss".
    mode (str, optional): Mode for the monitored metric ('min' or 'max'). Use 'min' for
        losses and 'max' for metrics such as accuracy. Defaults to "min".
    save_top_k (int, optional): Number of best models to save. Defaults to 1.
    save_last (bool, optional): Whether to save the model from the last epoch. Defaults to True.
    checkpoint_filename (str, optional): Filename for saved checkpoints. Defaults to "best_model".
    checkpoint_path (str, optional): Path to a checkpoint file to resume training from.
        Defaults to None.
    every_n_epochs (int, optional): Save a checkpoint every N epochs. Defaults to 1.
    **kwargs: Additional keyword arguments to pass to the datasets.

Returns:

    object: Trained SemanticSegmentationTask model.
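
A minimal usage sketch (directory paths are hypothetical):

task = train_classifier(
    image_root="data/images",
    label_root="data/labels",
    output_dir="output",
    in_channels=4,
    num_classes=14,
    epochs=20,
    batch_size=8,
    use_augmentation=True,
)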

Source code in geoai/classify.py
def train_classifier(
    image_root,
    label_root,
    output_dir="output",
    in_channels=4,
    num_classes=14,
    epochs=20,
    img_size=256,
    batch_size=8,
    sample_size=500,
    model="unet",
    backbone="resnet50",
    weights=True,
    num_filters=3,
    loss="ce",
    class_weights=None,
    ignore_index=None,
    lr=0.001,
    patience=10,
    freeze_backbone=False,
    freeze_decoder=False,
    transforms=None,
    use_augmentation=False,
    seed=42,
    train_val_test_split=(0.6, 0.2, 0.2),
    accelerator="auto",
    devices="auto",
    logger=None,
    callbacks=None,
    log_every_n_steps=10,
    use_distributed_sampler=False,
    monitor_metric="val_loss",
    mode="min",
    save_top_k=1,
    save_last=True,
    checkpoint_filename="best_model",
    checkpoint_path=None,
    every_n_epochs=1,
    **kwargs,
):
    """Train a semantic segmentation model on geospatial imagery.

    This function sets up datasets, model, trainer, and executes the training process
    for semantic segmentation tasks using geospatial data. It supports training
    from scratch or resuming from a checkpoint if available.

    Args:
        image_root (str): Path to directory containing imagery.
        label_root (str): Path to directory containing land cover labels.
        output_dir (str, optional): Directory to save model outputs and checkpoints.
            Defaults to "output".
        in_channels (int, optional): Number of input channels in the imagery.
            Defaults to 4.
        num_classes (int, optional): Number of classes in the segmentation task.
            Defaults to 14.
        epochs (int, optional): Number of training epochs. Defaults to 20.
        img_size (int, optional): Size of image patches for training. Defaults to 256.
        batch_size (int, optional): Batch size for training. Defaults to 8.
        sample_size (int, optional): Number of samples per epoch. Defaults to 500.
        model (str, optional): Model architecture to use. Defaults to "unet".
        backbone (str, optional): Backbone network for the model. Defaults to "resnet50".
        weights (bool, optional): Whether to use pretrained weights. Defaults to True.
        num_filters (int, optional): Number of filters for the model. Defaults to 3.
        loss (str, optional): Loss function to use ('ce', 'jaccard', or 'focal').
            Defaults to "ce".
        class_weights (list, optional): Class weights for loss function. Defaults to None.
        ignore_index (int, optional): Index to ignore in loss calculation. Defaults to None.
        lr (float, optional): Learning rate. Defaults to 0.001.
        patience (int, optional): Number of epochs with no improvement after which
            training will stop. Defaults to 10.
        freeze_backbone (bool, optional): Whether to freeze backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze decoder. Defaults to False.
        transforms (callable, optional): Transforms to apply to the data. Defaults to None.
        use_augmentation (bool, optional): Whether to apply data augmentation.
            Defaults to False.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.
        train_val_test_split (tuple, optional): Proportions for the train/val/test split.
            Defaults to (0.6, 0.2, 0.2).
        accelerator (str, optional): Accelerator to use for training ('cpu', 'gpu', etc.).
            Defaults to "auto".
        devices (str, optional): Number of devices to use for training. Defaults to "auto".
        logger (object, optional): Logger for tracking training progress. Defaults to None.
        callbacks (list, optional): List of callbacks for the trainer. Defaults to None.
        log_every_n_steps (int, optional): Frequency of logging training progress.
            Defaults to 10.
        use_distributed_sampler (bool, optional): Whether to use distributed sampling.
            Defaults to False.
        monitor_metric (str, optional): Metric to monitor for saving best model.
            Defaults to "val_loss".
        mode (str, optional): Mode for monitoring metric ('min' or 'max').
            Use 'min' for losses and 'max' for metrics like accuracy.
            Defaults to "min".
        save_top_k (int, optional): Number of best models to save.
            Defaults to 1.
        save_last (bool, optional): Whether to save the model from the last epoch.
            Defaults to True.
        checkpoint_filename (str, optional): Filename for saved checkpoints.
            Defaults to "best_model".
        checkpoint_path (str, optional): Path to a checkpoint file to resume
            training from. Defaults to None.
        every_n_epochs (int, optional): Save a checkpoint every N epochs.
            Defaults to 1.
        **kwargs: Additional keyword arguments to pass to the datasets.

    Returns:
        object: Trained SemanticSegmentationTask model.
    """
    import multiprocessing as mp
    import timeit

    import albumentations as A
    import lightning.pytorch as pl
    import torch
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.loggers import CSVLogger
    from torch.utils.data import DataLoader
    from torchgeo.datamodules import GeoDataModule
    from torchgeo.datasets import RasterDataset, stack_samples
    from torchgeo.datasets.splits import random_bbox_assignment
    from torchgeo.samplers import (
        GridGeoSampler,
        RandomBatchGeoSampler,
        RandomGeoSampler,
    )
    from torchgeo.trainers import SemanticSegmentationTask

    # Create a wrapper class for albumentations to work with TorchGeo format
    class AlbumentationsWrapper:
        def __init__(self, transform):
            self.transform = transform

        def __call__(self, sample):
            # Extract image and mask from TorchGeo sample format
            if "image" not in sample or "mask" not in sample:
                return sample

            image = sample["image"]
            mask = sample["mask"]

            # Albumentations expects channels last, but TorchGeo uses channels first
            # Convert (C, H, W) to (H, W, C) for image
            image_np = image.permute(1, 2, 0).numpy()
            mask_np = mask.squeeze(0).numpy() if mask.dim() > 2 else mask.numpy()

            # Apply transformation with named arguments
            transformed = self.transform(image=image_np, mask=mask_np)

            # Convert back to PyTorch tensors with channels first
            transformed_image = torch.from_numpy(transformed["image"]).permute(2, 0, 1)
            transformed_mask = torch.from_numpy(transformed["mask"]).unsqueeze(0)

            # Update the sample dictionary
            result = sample.copy()
            result["image"] = transformed_image
            result["mask"] = transformed_mask

            return result

    # Set up data augmentation if requested
    if use_augmentation:
        aug_transforms = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomRotate90(p=0.5),
                A.ShiftScaleRotate(
                    p=0.5, shift_limit=0.0625, scale_limit=0.1, rotate_limit=45
                ),
                A.RandomBrightnessContrast(
                    p=0.5, brightness_limit=0.2, contrast_limit=0.2
                ),
                A.GaussianBlur(p=0.3),
                A.GaussNoise(p=0.3),
                A.CoarseDropout(p=0.3, max_holes=8, max_height=32, max_width=32),
            ]
        )
        # Wrap the albumentations transforms
        transforms = AlbumentationsWrapper(aug_transforms)

    # Use all available CPU cores for the DataLoader workers
    workers = mp.cpu_count()

    # Define datasets
    class ImageDatasetClass(RasterDataset):
        filename_glob = "*.tif"
        is_image = True
        separate_files = False

    class LabelDatasetClass(RasterDataset):
        filename_glob = "*.tif"
        is_image = False
        separate_files = False

    # Prepare output directory
    test_dir = os.path.join(output_dir, "models")
    if not os.path.exists(test_dir):
        os.makedirs(test_dir)

    # Set up logger and checkpoint callback
    if logger is None:
        logger = CSVLogger(test_dir, name="lightning_logs")

    if callbacks is None:
        checkpoint_callback = ModelCheckpoint(
            dirpath=test_dir,
            filename=checkpoint_filename,
            save_top_k=save_top_k,
            monitor=monitor_metric,
            mode=mode,
            save_last=save_last,
            every_n_epochs=every_n_epochs,
            verbose=True,
        )
        callbacks = [checkpoint_callback]

    # Initialize the segmentation task
    task = SemanticSegmentationTask(
        model=model,
        backbone=backbone,
        weights=weights,
        in_channels=in_channels,
        num_classes=num_classes,
        num_filters=num_filters,
        loss=loss,
        class_weights=class_weights,
        ignore_index=ignore_index,
        lr=lr,
        patience=patience,
        freeze_backbone=freeze_backbone,
        freeze_decoder=freeze_decoder,
    )

    # Set up trainer
    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=epochs,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=log_every_n_steps,
        use_distributed_sampler=use_distributed_sampler,
        **kwargs,  # Pass any additional kwargs to the trainer
    )

    # Load datasets with transforms if augmentation is enabled

    if isinstance(image_root, RasterDataset):
        images = image_root
    else:
        images = ImageDatasetClass(paths=image_root, transforms=transforms, **kwargs)

    if isinstance(label_root, RasterDataset):
        labels = label_root
    else:
        labels = LabelDatasetClass(paths=label_root, **kwargs)

    # Create intersection dataset
    dataset = images & labels

    # Define custom datamodule for training
    class CustomGeoDataModule(GeoDataModule):
        def setup(self, stage: str) -> None:
            """Set up datasets.

            Args:
                stage: Either 'fit', 'validate', 'test', or 'predict'.
            """
            self.dataset = self.dataset_class(**self.kwargs)

            generator = torch.Generator().manual_seed(seed)
            (
                self.train_dataset,
                self.val_dataset,
                self.test_dataset,
            ) = random_bbox_assignment(dataset, train_val_test_split, generator)

            if stage in ["fit"]:
                self.train_batch_sampler = RandomBatchGeoSampler(
                    self.train_dataset, self.patch_size, self.batch_size, self.length
                )
            if stage in ["fit", "validate"]:
                self.val_sampler = GridGeoSampler(
                    self.val_dataset, self.patch_size, self.patch_size
                )
            if stage in ["test"]:
                self.test_sampler = GridGeoSampler(
                    self.test_dataset, self.patch_size, self.patch_size
                )

    # Create datamodule
    datamodule = CustomGeoDataModule(
        dataset_class=type(dataset),
        batch_size=batch_size,
        patch_size=img_size,
        length=sample_size,
        num_workers=workers,
        dataset1=images,
        dataset2=labels,
        collate_fn=stack_samples,
    )

    # Start training timer
    start = timeit.default_timer()

    # Check for existing checkpoint
    if checkpoint_path is not None:
        checkpoint_file = os.path.abspath(checkpoint_path)
    else:
        checkpoint_file = os.path.join(test_dir, "last.ckpt")

    if os.path.isfile(checkpoint_file):
        print("Resuming training from previous checkpoint...")
        trainer.fit(model=task, datamodule=datamodule, ckpt_path=checkpoint_file)
    else:
        print("Starting training from scratch...")
        trainer.fit(
            model=task,
            datamodule=datamodule,
        )

    training_time = timeit.default_timer() - start
    print(f"The time taken to train was: {training_time:.2f} seconds")

    # Report the best checkpoint; guard against user-supplied callbacks
    # that do not include a ModelCheckpoint
    for cb in callbacks:
        if isinstance(cb, ModelCheckpoint):
            print(f"Best model saved at: {cb.best_model_path}")
            break

    # Test the model
    trainer.test(model=task, datamodule=datamodule)

    return task
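
The returned task is a Lightning module, and the best checkpoint written by ModelCheckpoint can be reloaded later for inference, which is what classify_image does internally. A sketch (the checkpoint path is hypothetical):

from torchgeo.trainers import SemanticSegmentationTask

# Reload the best checkpoint saved during training
task = SemanticSegmentationTask.load_from_checkpoint("output/models/best_model.ckpt")
task.model.eval()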