hf module

This module contains utility functions for working with Hugging Face models.

get_model_config(model_id)

Get the model configuration for a Hugging Face model.

Parameters:

    model_id (str): The Hugging Face model ID. Required.

Returns:

    transformers.configuration_utils.PretrainedConfig: The model configuration.

Source code in geoai/hf.py
def get_model_config(model_id):
    """
    Get the model configuration for a Hugging Face model.

    Args:
        model_id (str): The Hugging Face model ID.

    Returns:
        transformers.configuration_utils.PretrainedConfig: The model configuration.
    """
    return AutoConfig.from_pretrained(model_id)
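
Example usage (a minimal sketch; the model ID is illustrative, and any Hugging Face model ID will work):

from geoai.hf import get_model_config

# Illustrative model ID; swap in any model from the Hugging Face Hub
config = get_model_config("facebook/mask2former-swin-large-cityscapes-semantic")
print(config.model_type)  # e.g. "mask2former"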

get_model_input_channels(model_id)

Check the number of input channels supported by a Hugging Face model.

Parameters:

    model_id (str): The Hugging Face model ID. Required.

Returns:

    int: The number of input channels the model accepts.

Raises:

    ValueError: If unable to determine the number of input channels.

Source code in geoai/hf.py
def get_model_input_channels(model_id):
    """
    Check the number of input channels supported by a Hugging Face model.

    Args:
        model_id (str): The Hugging Face model ID.

    Returns:
        int: The number of input channels the model accepts.

    Raises:
        ValueError: If unable to determine the number of input channels.
    """
    # Load the model configuration
    config = AutoConfig.from_pretrained(model_id)

    # For Mask2Former models
    if hasattr(config, "backbone_config"):
        if hasattr(config.backbone_config, "num_channels"):
            return config.backbone_config.num_channels

    # Try to load the model and inspect its architecture
    try:
        model = AutoModelForMaskedImageModeling.from_pretrained(model_id)

        # For Swin Transformer-based models like Mask2Former
        if hasattr(model, "backbone") and hasattr(model.backbone, "embeddings"):
            if hasattr(model.backbone.embeddings, "patch_embeddings"):
                # Swin models typically have patch embeddings that indicate channel count
                return model.backbone.embeddings.patch_embeddings.in_channels
    except Exception as e:
        print(f"Couldn't inspect model architecture: {e}")

    # Default for most vision models
    return 3
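
Example usage (a sketch, using the same illustrative model ID as above):

from geoai.hf import get_model_input_channels

channels = get_model_input_channels("facebook/mask2former-swin-large-cityscapes-semantic")
print(channels)  # most vision models report 3 (RGB)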

image_segmentation(tif_path, output_path, labels_to_extract=None, dtype='uint8', model_name=None, segmenter_args=None, **kwargs)

Segments an image with a Hugging Face segmentation model and saves the results as a single georeferenced image where each class has a unique integer value.

Parameters:

    tif_path (str): Path to the input georeferenced TIF file. Required.

    output_path (str): Path where the output georeferenced segmentation will be saved. Required.

    labels_to_extract (list, optional): List of labels to extract. If None, extracts all labels. Defaults to None.

    dtype (str, optional): Data type to use for the output mask. Defaults to 'uint8'.

    model_name (str, optional): Name of the Hugging Face model to use for segmentation, such as "facebook/mask2former-swin-large-cityscapes-semantic". Defaults to None, in which case that Cityscapes model is used. See https://huggingface.co/models?pipeline_tag=image-segmentation&sort=trending for options.

    segmenter_args (dict, optional): Additional arguments to pass to the segmenter. Defaults to None.

    **kwargs: Additional keyword arguments to pass to the segmentation pipeline.

Returns:

    tuple: (Path to saved image, dictionary mapping label names to their assigned values, dictionary mapping label names to confidence scores)

Source code in geoai/hf.py
def image_segmentation(
    tif_path,
    output_path,
    labels_to_extract=None,
    dtype="uint8",
    model_name=None,
    segmenter_args=None,
    **kwargs,
):
    """
    Segments an image with a Hugging Face segmentation model and saves the results
    as a single georeferenced image where each class has a unique integer value.

    Args:
        tif_path (str): Path to the input georeferenced TIF file.
        output_path (str): Path where the output georeferenced segmentation will be saved.
        labels_to_extract (list, optional): List of labels to extract. If None, extracts all labels.
        dtype (str, optional): Data type to use for the output mask. Defaults to "uint8".
        model_name (str, optional): Name of the Hugging Face model to use for segmentation,
            such as "facebook/mask2former-swin-large-cityscapes-semantic". Defaults to None.
            See https://huggingface.co/models?pipeline_tag=image-segmentation&sort=trending for options.
        segmenter_args (dict, optional): Additional arguments to pass to the segmenter.
            Defaults to None.
        **kwargs: Additional keyword arguments to pass to the segmentation pipeline.

    Returns:
        tuple: (Path to saved image, dictionary mapping label names to their assigned values,
            dictionary mapping label names to confidence scores)
    """
    # Load the original georeferenced image to extract metadata
    with rasterio.open(tif_path) as src:
        # Save the metadata for later use
        meta = src.meta.copy()
        # Get the dimensions
        height = src.height
        width = src.width
        # Get the transform and CRS for georeferencing
        # transform = src.transform
        # crs = src.crs

    # Initialize the segmentation pipeline
    if model_name is None:
        model_name = "facebook/mask2former-swin-large-cityscapes-semantic"

    kwargs["task"] = "image-segmentation"

    segmenter = pipeline(model=model_name, **kwargs)

    # Run the segmentation on the GeoTIFF
    if segmenter_args is None:
        segmenter_args = {}

    segments = segmenter(tif_path, **segmenter_args)

    # If no specific labels are requested, extract all available ones
    if labels_to_extract is None:
        labels_to_extract = [segment["label"] for segment in segments]

    # Create an empty mask to hold all the labels
    # (uint8 holds up to 255 classes; pass dtype="uint16" for more)
    combined_mask = np.zeros((height, width), dtype=dtype)

    # Create a dictionary to map labels to values and store scores
    label_to_value = {}
    label_to_score = {}

    # Process each segment we want to keep
    for i, segment in enumerate(
        [s for s in segments if s["label"] in labels_to_extract]
    ):
        # Assign a unique value to each label (starting from 1)
        value = i + 1
        label = segment["label"]
        score = segment["score"]

        label_to_value[label] = value
        label_to_score[label] = score

        # Convert PIL image to numpy array
        mask = np.array(segment["mask"])

        # Apply a threshold if it's a probability mask (not binary)
        if np.issubdtype(mask.dtype, np.floating):
            mask = (mask > 0.5).astype(np.uint8)

        # Resize if needed to match original dimensions
        # (nearest-neighbor resampling keeps the mask values discrete)
        if mask.shape != (height, width):
            mask_img = Image.fromarray(mask)
            mask_img = mask_img.resize((width, height), Image.NEAREST)
            mask = np.array(mask_img)

        # Add this class to the combined mask
        # Only overwrite if the pixel isn't already assigned to another class
        # This handles overlapping segments by giving priority to earlier segments
        combined_mask = np.where(
            (mask > 0) & (combined_mask == 0), value, combined_mask
        )

    # Update metadata for the output raster
    meta.update(
        {
            "count": 1,  # One band for the mask
            "dtype": dtype,  # Use uint8 for up to 255 classes
            "nodata": 0,  # 0 represents no class
        }
    )

    # Save the mask as a new georeferenced GeoTIFF
    with rasterio.open(output_path, "w", **meta) as dst:
        dst.write(combined_mask.astype(dtype)[np.newaxis, :, :])  # Add channel dimension

    # Create a CSV colormap file with scores included
    csv_path = os.path.splitext(output_path)[0] + "_colormap.csv"
    with open(csv_path, "w", newline="") as csvfile:
        fieldnames = ["ClassValue", "ClassName", "ConfidenceScore"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        writer.writeheader()
        for label, value in label_to_value.items():
            writer.writerow(
                {
                    "ClassValue": value,
                    "ClassName": label,
                    "ConfidenceScore": f"{label_to_score[label]:.4f}",
                }
            )

    return output_path, label_to_value, label_to_score
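
Example usage (a sketch; the file paths are placeholders, and the available label names depend on the chosen model's training classes):

from geoai.hf import image_segmentation

out_path, label_values, label_scores = image_segmentation(
    "naip_scene.tif",        # placeholder input GeoTIFF
    "segmentation.tif",      # placeholder output path
    labels_to_extract=None,  # None keeps every label the model predicts
    model_name="facebook/mask2former-swin-large-cityscapes-semantic",
)
print(label_values)  # e.g. {"road": 1, "building": 2, ...}

A colormap CSV (here, segmentation_colormap.csv) is written next to the output raster, mapping each class value to its label name and confidence score.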

mask_generation(input_path, output_mask_path, output_csv_path, model='facebook/sam-vit-base', confidence_threshold=0.5, points_per_side=32, crop_size=None, batch_size=1, band_indices=None, min_object_size=0, generator_kwargs=None, **kwargs)

Process a GeoTIFF using SAM mask generation and save results as a GeoTIFF and CSV.

The function reads a GeoTIFF image, applies the SAM mask generator from the Hugging Face transformers pipeline, rasterizes the resulting masks to create a labeled mask GeoTIFF, and saves mask scores and geometries to a CSV file.

Parameters:

    input_path (str): Path to the input GeoTIFF image. Required.

    output_mask_path (str): Path where the output mask GeoTIFF will be saved. Required.

    output_csv_path (str): Path where the mask scores CSV will be saved. Required.

    model (str): Hugging Face model checkpoint for the SAM model. Defaults to 'facebook/sam-vit-base'.

    confidence_threshold (float): Minimum confidence score for masks to be included. Defaults to 0.5.

    points_per_side (int): Number of points to sample along each side of the image. Defaults to 32.

    crop_size (Optional[int]): Size of image crops for processing. If None, process the full image. Defaults to None.

    batch_size (int): Batch size for inference. Defaults to 1.

    band_indices (Optional[List[int]]): List of band indices to use. If None, use all bands. Defaults to None.

    min_object_size (int): Minimum size in pixels for objects to be included; smaller masks are filtered out. Defaults to 0.

    generator_kwargs (Optional[Dict]): Additional keyword arguments to pass to the mask generator. Defaults to None.

Returns:

    Tuple[str, str]: Tuple containing the paths to the saved mask GeoTIFF and CSV file.

Raises:

    ValueError: If the input file cannot be opened or processed.

    RuntimeError: If mask generation fails.

Source code in geoai/hf.py
def mask_generation(
    input_path: str,
    output_mask_path: str,
    output_csv_path: str,
    model: str = "facebook/sam-vit-base",
    confidence_threshold: float = 0.5,
    points_per_side: int = 32,
    crop_size: Optional[int] = None,
    batch_size: int = 1,
    band_indices: Optional[List[int]] = None,
    min_object_size: int = 0,
    generator_kwargs: Optional[Dict] = None,
    **kwargs,
) -> Tuple[str, str]:
    """
    Process a GeoTIFF using SAM mask generation and save results as a GeoTIFF and CSV.

    The function reads a GeoTIFF image, applies the SAM mask generator from the
    Hugging Face transformers pipeline, rasterizes the resulting masks to create
    a labeled mask GeoTIFF, and saves mask scores and geometries to a CSV file.

    Args:
        input_path: Path to the input GeoTIFF image.
        output_mask_path: Path where the output mask GeoTIFF will be saved.
        output_csv_path: Path where the mask scores CSV will be saved.
        model: HuggingFace model checkpoint for the SAM model.
        confidence_threshold: Minimum confidence score for masks to be included.
        points_per_side: Number of points to sample along each side of the image.
        crop_size: Size of image crops for processing. If None, process the full image.
        batch_size: Batch size for inference.
        band_indices: List of band indices to use. If None, use all bands.
        min_object_size: Minimum size in pixels for objects to be included. Smaller masks will be filtered out.
        generator_kwargs: Additional keyword arguments to pass to the mask generator.

    Returns:
        Tuple containing the paths to the saved mask GeoTIFF and CSV file.

    Raises:
        ValueError: If the input file cannot be opened or processed.
        RuntimeError: If mask generation fails.
    """
    # Set up the mask generator
    print("Setting up mask generator...")
    mask_generator = pipeline(model=model, task="mask-generation", **kwargs)

    # Open the GeoTIFF file
    try:
        print(f"Reading input GeoTIFF: {input_path}")
        with rasterio.open(input_path) as src:
            # Read metadata
            profile = src.profile
            # transform = src.transform
            # crs = src.crs

            # Read the image data
            if band_indices is not None:
                print(f"Using specified bands: {band_indices}")
                image_data = np.stack([src.read(i + 1) for i in band_indices])
            else:
                print("Using all bands")
                image_data = src.read()

            # Handle image with more than 3 bands (convert to RGB for visualization)
            if image_data.shape[0] > 3:
                print(
                    f"Converting {image_data.shape[0]} bands to RGB (using first 3 bands)"
                )
                # Select first three bands or perform other band combination
                image_data = image_data[:3]
            elif image_data.shape[0] == 1:
                print("Duplicating single band to create 3-band image")
                # Duplicate single band to create a 3-band image
                image_data = np.vstack([image_data] * 3)

            # Transpose to HWC format for the model
            image_data = np.transpose(image_data, (1, 2, 0))

            # Normalize the image if needed
            if image_data.dtype != np.uint8:
                print(f"Normalizing image from {image_data.dtype} to uint8")
                image_data = (image_data / image_data.max() * 255).astype(np.uint8)
    except Exception as e:
        raise ValueError(f"Failed to open or process input GeoTIFF: {e}")

    # Process the image with the mask generator
    try:
        # image_data is already HWC uint8 here (normalized above if needed)

        # Create a PIL Image from the numpy array
        print("Converting to PIL Image for mask generation")
        pil_image = Image.fromarray(image_data)

        # Use the SAM pipeline for mask generation
        if generator_kwargs is None:
            generator_kwargs = {}

        print("Running mask generation...")
        mask_results = mask_generator(
            pil_image,
            points_per_side=points_per_side,
            crop_n_points_downscale_factor=1 if crop_size is None else 2,
            point_grids=None,
            pred_iou_thresh=confidence_threshold,
            stability_score_thresh=confidence_threshold,
            crops_n_layers=0 if crop_size is None else 1,
            crop_overlap_ratio=0.5,
            batch_size=batch_size,
            **generator_kwargs,
        )

        num_initial = (
            len(mask_results["masks"])
            if isinstance(mask_results, dict) and "masks" in mask_results
            else len(mask_results)
        )
        print(f"Number of initial masks: {num_initial}")

    except Exception as e:
        raise RuntimeError(f"Mask generation failed: {e}")

    # Create a mask raster with unique IDs for each mask
    mask_raster = np.zeros((image_data.shape[0], image_data.shape[1]), dtype=np.uint32)
    mask_records = []

    # Process each mask based on the structure of mask_results
    if (
        isinstance(mask_results, dict)
        and "masks" in mask_results
        and "scores" in mask_results
    ):
        # Handle dictionary with 'masks' and 'scores' lists
        print("Processing masks...")
        total_masks = len(mask_results["masks"])

        # Create progress bar
        for i, (mask_data, score) in enumerate(
            tqdm(
                zip(mask_results["masks"], mask_results["scores"]),
                total=total_masks,
                desc="Processing masks",
            )
        ):
            mask_id = i + 1  # Start IDs at 1

            # Convert to numpy if not already
            if not isinstance(mask_data, np.ndarray):
                # Try to convert from tensor or other format if needed
                try:
                    mask_data = np.array(mask_data)
                except Exception:
                    print(f"Could not convert mask at index {i} to numpy array")
                    continue

            mask_binary = mask_data.astype(bool)
            area_pixels = np.sum(mask_binary)

            # Skip if mask is smaller than the minimum size
            if area_pixels < min_object_size:
                continue

            # Add the mask to the raster with a unique ID
            mask_raster[mask_binary] = mask_id

            # Create a record for the CSV - without geometry calculation
            mask_records.append(
                {"mask_id": mask_id, "score": float(score), "area_pixels": area_pixels}
            )
    elif isinstance(mask_results, list):
        # Handle list of dictionaries format (SAM original format)
        print("Processing masks...")
        total_masks = len(mask_results)

        # Create progress bar
        for i, mask_result in enumerate(tqdm(mask_results, desc="Processing masks")):
            mask_id = i + 1  # Start IDs at 1

            # Try different possible key names for masks and scores
            mask_data = None
            score = None

            if isinstance(mask_result, dict):
                # Try to find mask data
                if "segmentation" in mask_result:
                    mask_data = mask_result["segmentation"]
                elif "mask" in mask_result:
                    mask_data = mask_result["mask"]

                # Try to find score
                if "score" in mask_result:
                    score = mask_result["score"]
                elif "predicted_iou" in mask_result:
                    score = mask_result["predicted_iou"]
                elif "stability_score" in mask_result:
                    score = mask_result["stability_score"]
                else:
                    score = 1.0  # Default score if none found
            else:
                # If mask_result is not a dict, it might be the mask directly
                try:
                    mask_data = np.array(mask_result)
                    score = 1.0  # Default score
                except Exception:
                    print(f"Could not process mask at index {i}")
                    continue

            if mask_data is not None:
                # Convert to numpy if not already
                if not isinstance(mask_data, np.ndarray):
                    try:
                        mask_data = np.array(mask_data)
                    except Exception:
                        print(f"Could not convert mask at index {i} to numpy array")
                        continue

                mask_binary = mask_data.astype(bool)
                area_pixels = np.sum(mask_binary)

                # Skip if mask is smaller than the minimum size
                if area_pixels < min_object_size:
                    continue

                # Add the mask to the raster with a unique ID
                mask_raster[mask_binary] = mask_id

                # Create a record for the CSV - without geometry calculation
                mask_records.append(
                    {
                        "mask_id": mask_id,
                        "score": float(score),
                        "area_pixels": area_pixels,
                    }
                )
    else:
        # If we couldn't figure out the format, raise an error
        raise ValueError(f"Unexpected format for mask_results: {type(mask_results)}")

    print(f"Number of final masks (after size filtering): {len(mask_records)}")

    # Save the mask raster as a GeoTIFF
    print(f"Saving mask GeoTIFF to {output_mask_path}")
    output_profile = profile.copy()
    output_profile.update(dtype=rasterio.uint32, count=1, compress="lzw", nodata=0)

    with rasterio.open(output_mask_path, "w", **output_profile) as dst:
        dst.write(mask_raster.astype(rasterio.uint32), 1)

    # Save the mask data as a CSV
    print(f"Saving mask metadata to {output_csv_path}")
    mask_df = pd.DataFrame(mask_records)
    mask_df.to_csv(output_csv_path, index=False)

    print("Processing complete!")
    return output_mask_path, output_csv_path
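
Example usage (a sketch with placeholder paths; "facebook/sam-vit-base" is the documented default checkpoint):

from geoai.hf import mask_generation

mask_path, csv_path = mask_generation(
    "naip_scene.tif",        # placeholder input GeoTIFF
    "masks.tif",             # placeholder output mask path
    "mask_scores.csv",       # placeholder output CSV path
    confidence_threshold=0.5,
    min_object_size=100,     # drop masks smaller than 100 pixels
)

Each retained mask receives a unique integer ID in the output GeoTIFF (0 is nodata), and the CSV records each mask's ID, confidence score, and pixel area.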