4 Clay Deforestation Segmentation Model

Authors and affiliations:

  • Ate Poortinga, Spatial Informatics Group
  • Daniel Marc dela Torre, Spatial Informatics Group
  • Chanarun Saisaward, Spatial Informatics Group
  • M Warizmi Wafiq, Spatial Informatics Group
  • Ponlawat Weerapanpisit, Spatial Informatics Group
  • Vanna Teck, Spatial Informatics Group
  • Weraphong Suaruang, Spatial Informatics Group
  • Wipawinee Khamnoi, Spatial Informatics Group

How to cite this chapter:

Poortinga, A., dela Torre, D. M., Saisaward, C., Wafiq, M. W., Weerapanpisit, P., Teck, V., Suaruang, W., & Khamnoi, W. (2026). Clay Deforestation Segmentation Model. Zenodo. https://doi.org/10.5281/zenodo.20547805

Part of: Mayer, T., Bhandari, B., & Saah, D. (2026). EarthRISE Applied Artificial Intelligence and Deep Learning Book. Zenodo. https://doi.org/10.5281/zenodo.20547797

4.0.1 Course Project - Deforestation Analysis

This Colab notebook demonstrates the training and evaluation of a segmentation model for deforestation detection using Sentinel-2 remote sensing data. The workflow pairs image chips derived from Earth Engine (Gorelick et al., 2017) and user-provided deforestation labels with a Clay-based Vision Transformer encoder to perform semantic segmentation of forest loss.

4.1 🌍 Project Context

Deforestation is a critical environmental challenge impacting biodiversity, climate, and local ecosystems. Monitoring deforestation over time is essential for sustainable land management and conservation efforts.

Deforestation monitoring is difficult in operational settings because disturbance signals can be subtle, intermittent, and easily confused with phenology, forest degradation, or partial canopy loss. Time-series approaches help address these challenges. However, they require careful parameterization, sustained access to Earth observation data, and practical decisions about thresholds and interpretation in complex landscapes (Aryal et al., 2021). In tropical regions, persistent cloud cover further complicates optical monitoring, which motivates the use of SAR data and scalable processing environments. Recent work has shown that semantic segmentation implemented in such environments can support near real-time forest disturbance detection at scale (Kilbride et al., 2023).

In this project, we develop a deep learning model to detect deforestation from satellite imagery. The model trains a segmentation head on top of the Clay foundation model, a pretrained Vision Transformer encoder for Earth observation.

4.2 🎯 Objectives:

  1. Data Preprocessing: Convert geospatial data into model-ready formats.
  2. Model Training: Train the Clay-based segmentation model to detect changes in forest cover.
  3. Model Evaluation: Assess performance using metrics like IoU, F1 Score, Precision, and Recall.
  4. Inference: Predict deforestation on new satellite imagery.
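All four evaluation metrics can be derived from pixel-level counts of true positives, false positives and false negatives. The NumPy sketch below is only an illustration of the definitions. The helper name `binary_segmentation_metrics` is hypothetical, and the notebook itself relies on torchmetrics' `BinaryJaccardIndex`, `BinaryF1Score`, `BinaryPrecision` and `BinaryRecall`. Returning 0.0 when a denominator is zero (for example, an empty prediction and an empty label) is one convention among several, and libraries differ on this edge case.

```python
import numpy as np


def binary_segmentation_metrics(pred: np.ndarray, target: np.ndarray) -> dict:
    """IoU, F1, precision and recall for binary masks (1 = deforestation)."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)

    tp = np.logical_and(pred, target).sum()   # predicted loss that is labeled loss
    fp = np.logical_and(pred, ~target).sum()  # predicted loss on labeled forest
    fn = np.logical_and(~pred, target).sum()  # labeled loss that was missed

    def safe_div(num, den):
        # Convention assumed here: 0.0 when the denominator is zero
        return float(num) / float(den) if den > 0 else 0.0

    return {
        "iou": safe_div(tp, tp + fp + fn),
        "f1": safe_div(2 * tp, 2 * tp + fp + fn),
        "precision": safe_div(tp, tp + fp),
        "recall": safe_div(tp, tp + fn),
    }


pred = np.array([[1, 1, 0], [0, 1, 0]])
target = np.array([[1, 0, 0], [0, 1, 1]])
print(binary_segmentation_metrics(pred, target))
```

In this toy example, 2 pixels are true positives, 1 is a false positive and 1 is a false negative. That gives an IoU of 0.5, while precision, recall and F1 are all 2/3.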

4.3 💻 Libraries and Tools:

  • PyTorch: Deep learning framework for model training.
  • Lightning: Simplifies the training loop and model management.
  • Earth Engine API: Access to satellite imagery for data extraction.
  • Rasterio: Reading and writing geospatial data (GeoTIFF).
  • Segmentation Models PyTorch: Advanced loss functions for segmentation.
  • Other Utilities: NumPy, SciPy, YAML, Box, Matplotlib.

4.3.1 ✅ Workflow:

  1. Set up the environment and install libraries.
  2. Download satellite data from Google Earth Engine (GEE).
  3. Train the Clay-based model to detect deforestation.
  4. Validate the model on test data.
  5. Run inference on new images and visualize the results.

4.3.2 Clay Foundation Model

Clay is a foundation model of Earth. Foundation models trained on Earth observation (EO) data efficiently distill and synthesize vast amounts of environmental information, allowing them to generalize this knowledge to specific downstream applications. This makes them versatile and powerful tools for nature and climate use cases.

Clay is an open-source foundation model for Earth observation that uses Vision Transformer (ViT) architectures to learn general-purpose representations from multi-sensor satellite imagery. The model takes satellite imagery together with spatial (location) and temporal (time) information as input and outputs embeddings, which are mathematical representations of a given area at a specific time on Earth’s surface.
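Clay encodes time and location cyclically, so that, for instance, late December and early January end up close together. The sketch below is a minimal version of that idea. It assumes the sine/cosine scheme used in the Clay v1 example notebooks: week of year and hour of day for time, and latitude and longitude in radians for location. The function names mirror those examples but should be treated as illustrative. Check the repository for the exact normalization used by the release you install.

```python
import math
from datetime import datetime


def normalize_timestamp(date: datetime):
    """Encode week-of-year and hour-of-day as (sin, cos) pairs on the unit circle."""
    week = date.isocalendar()[1] * 2 * math.pi / 52
    hour = date.hour * 2 * math.pi / 24
    return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour))


def normalize_latlon(lat: float, lon: float):
    """Encode latitude and longitude (in degrees) as (sin, cos) pairs."""
    lat_rad = math.radians(lat)
    lon_rad = math.radians(lon)
    return (math.sin(lat_rad), math.cos(lat_rad)), (math.sin(lon_rad), math.cos(lon_rad))


week_enc, hour_enc = normalize_timestamp(datetime(2024, 1, 1, 0, 0))
lat_enc, lon_enc = normalize_latlon(12.5, 104.9)  # roughly central Cambodia
print(week_enc, hour_enc, lat_enc, lon_enc)
```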

Clay is trained using self-supervised learning (SSL) with a Masked Autoencoder (MAE) approach and is designed to be paired with downstream task-specific components, enabling applications such as semantic segmentation, classification, and change detection using user-provided training data.
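The masked-autoencoder idea can be illustrated through its masking step. A random subset of patch tokens is hidden, the encoder sees only the remaining tokens, and a decoder is trained to reconstruct the hidden patches. The NumPy sketch below shows only the masking. `random_masking` is a hypothetical helper name. The 0.75 ratio is the default from the original MAE work and is an assumption here, not a documented Clay setting.

```python
import numpy as np


def random_masking(tokens: np.ndarray, mask_ratio: float = 0.75, seed: int = 0):
    """MAE-style random masking of patch tokens for one image.

    tokens: (num_patches, dim) array.
    Returns the visible tokens, a boolean mask (True = hidden from the encoder),
    and the sorted indices of the visible patches.
    """
    rng = np.random.default_rng(seed)
    num_patches = tokens.shape[0]
    num_keep = int(num_patches * (1 - mask_ratio))

    # Shuffle patch indices and keep the first `num_keep` as visible
    perm = rng.permutation(num_patches)
    keep_idx = np.sort(perm[:num_keep])

    mask = np.ones(num_patches, dtype=bool)
    mask[keep_idx] = False
    return tokens[keep_idx], mask, keep_idx


tokens = np.random.default_rng(1).normal(size=(64, 16))  # 8x8 grid of 16-dim tokens
visible, mask, keep_idx = random_masking(tokens)
print(visible.shape, mask.sum())  # (16, 16) 48
```

Because the encoder processes only a quarter of the tokens, pretraining stays cheap. The reconstruction loss is then computed on the masked patches.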

The project is developed as an open research and engineering effort and is publicly available, including model code, configuration files, and pretrained checkpoints: GitHub: https://github.com/Clay-foundation/model

Web: https://www.madewithclay.org

4.4 Computational Requirements

Training and fine-tuning the Clay Vision Transformer encoder for semantic segmentation is computationally intensive. A CUDA-enabled GPU is strongly recommended. In practice, a GPU with at least 16 GB of VRAM is required to run the training workflow within a reasonable time and without out-of-memory errors. Running the notebook on a CPU-only machine is not recommended for model training and should be limited to code inspection or lightweight testing.

When using Google Colab, ensure that a GPU runtime is selected (Runtime → Change runtime type → Hardware accelerator → GPU) before executing the training steps.
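A quick way to confirm that the selected runtime meets the memory requirement is to query the GPU from PyTorch. The sketch below is an illustration rather than part of the original notebook, and both function names are hypothetical. The 0.9 tolerance is also an assumption: it allows for cards sold as 16 GB that report slightly less usable memory.

```python
def meets_vram_requirement(total_bytes: int, min_gb: float = 16.0, tolerance: float = 0.9) -> bool:
    """True if the device reports at least tolerance * min_gb GiB of memory."""
    total_gib = total_bytes / 1024**3
    return total_gib >= min_gb * tolerance


def describe_gpu(min_gb: float = 16.0) -> str:
    """Return a one-line summary of the first CUDA device, if any."""
    try:
        import torch
    except ImportError:
        return "PyTorch is not installed; cannot query the GPU."
    if not torch.cuda.is_available():
        return "No CUDA GPU detected: select a GPU runtime before training."
    props = torch.cuda.get_device_properties(0)
    total_gib = props.total_memory / 1024**3
    ok = meets_vram_requirement(props.total_memory, min_gb)
    status = "OK" if ok else "below the recommended size"
    return f"{props.name}: {total_gib:.1f} GiB ({status})"


print(describe_gpu())
```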

Let’s get started! 🚀

5 Step 1: Project Setup

In this step, we will set up the environment by installing the necessary libraries.
These libraries are essential for data handling, model training, and interacting with Google Earth Engine (GEE).

5.1 Why These Libraries?

  • NumPy: Efficient numerical operations.
  • Rasterio: Reading and writing GeoTIFF files.
  • Matplotlib: Visualization of image chips.
  • Google Auth & Earth Engine API: Access and authenticate with Google Earth Engine.
  • Requests: Downloading data from GEE.
  • GDAL: Handling geospatial data formats (GeoTIFF).

By installing these libraries now, we ensure the environment is properly configured before data processing.

# Install necessary libraries for data retrieval and processing
!pip install numpy rasterio matplotlib google-auth earthengine-api requests
!apt-get install -y python3-gdal
!pip -q install vit-pytorch

print("✅ Libraries installed successfully!")

6 Step 2: Mounting Google Drive

To store the downloaded data persistently, we will mount Google Drive.
This allows us to save the downloaded image chips directly to a Drive folder.

6.1 Why Use Google Drive?

  • Persistent storage between Colab sessions.
  • Easy access and sharing of data.
  • Organized structure for training and validation datasets.

6.1.1 Instructions:

  1. Authorize Google Drive access when prompted.
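  2. After mounting, the notebook creates the data folders in your Drive if they do not already exist:
    • MyDrive/Deforestation_Project/data/, with train/ and val/ subfolders that each contain before/, after/, and label/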
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')
# Set the base directory for storing data in Google Drive
base_dir = "/content/drive/MyDrive/Deforestation_Project/data"

# Create necessary directories for train and validation data
for split in ["train", "val"]:
    for sub in ["before", "after", "label"]:
        os.makedirs(f"{base_dir}/{split}/{sub}", exist_ok=True)

print("✅ Google Drive mounted and directories created!")

7 Step 3: Authenticating and Initializing Google Earth Engine (GEE)

Google Earth Engine (GEE) is a powerful platform that provides access to vast amounts of satellite imagery.
In this step, we will authenticate and initialize GEE to prepare for data retrieval.

7.1 Why Use GEE?

  • Access to high-resolution, multi-temporal satellite imagery.
  • Efficient data processing through cloud-based computation.
  • Ideal for large-scale environmental monitoring tasks.

7.1.1 Instructions:

  1. Authenticate with your Google account when prompted.
  2. Follow the link, grant permissions, and paste the code into the prompt.

7.2 Setting Up Your Google Earth Engine (GEE) Project

In this step, we will initialize Google Earth Engine (GEE) using your Google Cloud project ID.
Earth Engine requires this ID to initialize and to access your GEE resources.

7.3 🔑 Important:

  • Replace the placeholder update-project-name-here (assigned to the projectname variable) with your own Cloud project ID.
  • You can find the project ID in the Google Cloud Console under IAM & Admin > Settings.
  • A project ID is lowercase and hyphenated, for example your-project-id. Earth Engine needs the ID, not the project's display name.

7.3.1 Instructions:

  1. Locate your project ID in the Google Cloud Console.
  2. Set the projectname variable in the code below to that ID.
  3. Run the cell to authenticate and initialize Earth Engine.

Please note: the Google Cloud project must be registered for Earth Engine use (commercial or non-commercial). If the project is not registered, Earth Engine initialization will fail and the Colab notebook will not be able to connect to Earth Engine.

projectname = "update-project-name-here"
import ee


def ee_init():
    """Initialize Earth Engine for the project in `projectname`, using the high-volume endpoint."""
    # The high-volume endpoint is intended for many parallel, automated requests,
    # such as the getDownloadURL calls made in Step 4.
    ee.Initialize(
        project=projectname,
        opt_url="https://earthengine-highvolume.googleapis.com",
    )


# Authenticate (opens an authorization prompt in Colab), then initialize Earth Engine
ee.Authenticate()
ee_init()
print("✅ Earth Engine initialized successfully!")

Step 4: Downloading and Storing Data from GEE

This workflow uses Sentinel-2 multispectral optical imagery as the primary Earth observation input for deforestation segmentation. Sentinel-2 provides global coverage at 10–20 m spatial resolution for its main bands. With two satellites in operation, it revisits a given location about every five days at the equator, which enables consistent monitoring of forest cover change over time. The multispectral bands capture vegetation structure and condition through visible and near-infrared reflectance, which makes them well suited for distinguishing intact forest from cleared or disturbed areas.

Sentinel-2 is selected for this task because it offers an effective balance between spatial resolution, temporal coverage, and open data access. Compared to very high–resolution commercial imagery, Sentinel-2 enables scalable, repeatable analysis over large regions without licensing constraints. Compared to coarser-resolution sensors, it preserves sufficient spatial detail to support pixel-level semantic segmentation of deforestation patterns, making it a practical choice for both research and operational monitoring.
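The sketch below makes the role of red and near-infrared reflectance concrete. It computes NDVI = (NIR − Red) / (NIR + Red) for the before and after dates and flags pixels with a large drop. This is a classical rule-based baseline, not what the Clay model does, and the function names are illustrative. For Sentinel-2, the red and NIR bands are B4 and B8. Which array positions hold them in these chips is an assumption, and so is the 0.3 threshold.

```python
import numpy as np


def ndvi(red: np.ndarray, nir: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalized Difference Vegetation Index: (NIR - Red) / (NIR + Red)."""
    red = red.astype(np.float32)
    nir = nir.astype(np.float32)
    return (nir - red) / (nir + red + eps)  # eps avoids division by zero


def ndvi_drop(before_red, before_nir, after_red, after_nir, threshold=0.3):
    """Flag pixels whose NDVI fell by more than `threshold` between the two dates."""
    delta = ndvi(before_red, before_nir) - ndvi(after_red, after_nir)
    return delta > threshold


# Toy 1x2 scene: pixel 0 stays forested, pixel 1 is cleared
before_red = np.array([[300, 300]])
before_nir = np.array([[3000, 3000]])
after_red = np.array([[300, 1500]])
after_nir = np.array([[3000, 1800]])
print(ndvi_drop(before_red, before_nir, after_red, after_nir))
```

Simple index differencing like this is sensitive to haze, phenology and thresholds, which is part of the motivation for a learned segmentation model.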

In this step, we will download satellite image chips directly from Google Earth Engine (GEE) and save them to Google Drive.
The data will include:

  • Before images (GeoTIFF)
  • After images (GeoTIFF)
  • Deforestation labels (GeoTIFF)

7.4 Why Save as GeoTIFF?

  • Efficient storage format for geospatial data.
  • Easily readable by GIS software and machine learning pipelines.
  • Retains georeferencing information.
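A GeoTIFF stores its georeferencing as an affine geotransform. write_geotiff (defined later in this step) uses GDAL's six-coefficient convention with [xmin, scale, 0, ymax, 0, -scale]. In that convention, pixel positions map to coordinates in the image CRS as sketched below. The helper name, the chip origin and the 10 m pixel size in the example are hypothetical.

```python
def pixel_to_map(geotransform, row: float, col: float):
    """Map a (row, col) pixel position to CRS coordinates with a GDAL-style geotransform.

    geotransform = (x_origin, pixel_width, x_rotation, y_origin, y_rotation, pixel_height)
    The rotation terms are 0 and pixel_height is negative for north-up images.
    """
    x0, pixel_w, rot_x, y0, rot_y, pixel_h = geotransform
    x = x0 + col * pixel_w + row * rot_x
    y = y0 + col * rot_y + row * pixel_h
    return x, y


# Hypothetical 10 m chip whose top-left corner is at (500000, 1400000) in a UTM CRS
gt = (500000.0, 10.0, 0.0, 1400000.0, 0.0, -10.0)
print(pixel_to_map(gt, 0, 0))      # (500000.0, 1400000.0): top-left corner of the first pixel
print(pixel_to_map(gt, 0.5, 0.5))  # (500005.0, 1399995.0): centre of the first pixel
print(pixel_to_map(gt, 512, 512))  # (505120.0, 1394880.0): bottom-right corner of a 512x512 chip
```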

7.4.1 Steps:

  1. Define the data download function.
  2. Save the images in an organized folder structure in Google Drive.
  3. Print confirmation of successful downloads.

7.4.2 Notes:

The servir-ee Google Earth Engine assets project is publicly available. Because of this, users can directly access the datasets and download the required data to their own Google Drive without requesting additional permissions.

Execution time note: downloading the data may take several minutes depending on connection speed and system load. In our case, this step took approximately 4 minutes to complete. This runtime is expected and has been retained in the notebook to reflect real execution conditions.

import os
import io
import requests
import numpy as np
import rasterio
from rasterio.transform import from_origin
from numpy.lib.recfunctions import structured_to_unstructured
from osgeo import gdal  # needed by write_geotiff below
import ee


def write_geotiff(
    raster: np.ndarray,
    out_file: str,
    origin_xy: tuple[float, float],  # (xmin, ymax) in IMAGE CRS units (top-left corner)
    img,                             # ee.Image
    patch_size: int
) -> None:
    """
    Saves a raster (single or multi-band numpy array) as a GeoTIFF with correct
    georeferencing using Earth Engine image metadata, while keeping raster
    dimensions fixed to patch_size.

    Args:
        raster (np.ndarray): Raster data, shape (H, W) or (H, W, bands).
        out_file (str): Output file path for GeoTIFF.
        origin_xy (tuple[float, float]): (xmin, ymax) of the top-left pixel in the image CRS.
        img (ee.Image): Earth Engine image for projection and scale info.
        patch_size (int): Fixed raster width & height in pixels.

    Notes:
        - Raster shape is forced/validated to (patch_size, patch_size).
        - CRS (WKT) and nominal pixel scale are taken from img.projection().
    """
    xmin, ymax = origin_xy

    # Get projection (WKT) + nominal scale from EE image
    proj_info = img.projection().getInfo()
    if "wkt" not in proj_info:
        raise RuntimeError("EE projection info missing WKT; cannot set GeoTIFF projection.")
    scale = img.projection().nominalScale().getInfo()  # pixel size in CRS units

    # Force 3D (H, W, bands)
    if raster.ndim == 2:
        raster = raster[:, :, np.newaxis]

    # Convert to float32 for GDAL compatibility
    raster = raster.astype(np.float32)

    # Validate expected dimensions
    if raster.shape[0] != patch_size or raster.shape[1] != patch_size:
        raise ValueError(f"Raster shape {raster.shape[:2]} != ({patch_size}, {patch_size})")

    _, _, bands = raster.shape

    driver = gdal.GetDriverByName("GTiff")
    out_raster = driver.Create(out_file, patch_size, patch_size, bands, gdal.GDT_Float32)
    if out_raster is None:
        raise RuntimeError(f"Failed to create GeoTIFF at {out_file}")

    # GeoTransform: top-left x, pixel width, rot, top-left y, rot, pixel height (negative)
    geotransform = [xmin, scale, 0, ymax, 0, -scale]
    out_raster.SetGeoTransform(geotransform)

    # Set projection WKT from EE
    out_raster.SetProjection(proj_info["wkt"])

    # Write each band
    for i in range(bands):
        band = out_raster.GetRasterBand(i + 1)
        band.WriteArray(raster[:, :, i])

    out_raster = None  # flush + close
    print(f"GeoTIFF saved at {out_file}")


def download_chip_by_index(i: int, split: str = "train", patch_size: int = 512, scale: int = 10) -> None:
    """
    Downloads "before", "after", and "label" image chips from Google Earth Engine
    for a given index, and saves them as GeoTIFFs with correct georeferencing.
    """
    chip_id = f"{i:04d}"
    before_path = f"{base_dir}/{split}/before/{chip_id}.tif"
    after_path  = f"{base_dir}/{split}/after/{chip_id}.tif"
    label_path  = f"{base_dir}/{split}/label/{chip_id}.tif"

    if all(os.path.exists(p) for p in [before_path, after_path, label_path]):
        print(f"Skipped {split}/{chip_id} (already exists)")
        return

    try:
        ee_base = "projects/mrv-cambodia/assets/trainingData"
        before_asset = f"{ee_base}/{split}Before/{chip_id}"
        after_asset  = f"{ee_base}/{split}After/{chip_id}"
        label_asset  = f"{ee_base}/{split}Labels/{chip_id}"

        before = ee.Image(before_asset)
        after  = ee.Image(after_asset)
        label  = ee.Image(label_asset)

        lonlat = tuple(before.geometry().centroid().coordinates().getInfo())
        point = ee.Geometry.Point(lonlat)
        region_ll = point.buffer(scale * patch_size / 2, 1).bounds(1)

        for img, path in zip([before, after, label], [before_path, after_path, label_path]):

            img_proj = img.projection()
            proj_info = img_proj.getInfo()
            crs = proj_info.get("crs", "EPSG:4326")
            transform = proj_info.get("transform")

            region_proj = region_ll.transform(img_proj, 1).bounds(1, img_proj)
            coords = region_proj.coordinates().getInfo()[0]
            xs = [c[0] for c in coords]
            ys = [c[1] for c in coords]
            xmin = min(xs)
            xmax = max(xs)
            ymin = min(ys)
            ymax = max(ys)

            url = img.getDownloadURL({
                "region": region_ll,
                "dimensions": [patch_size, patch_size],
                "format": "NPY",
            })
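            response = requests.get(url, timeout=60)
            if response.status_code == 429:
                # Earth Engine is rate-limiting requests; fail with a clear message
                raise RuntimeError(f"Rate limited by Earth Engine (HTTP 429): {response.text}")
            response.raise_for_status()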

            raster = np.load(io.BytesIO(response.content), allow_pickle=True)

            if isinstance(raster, np.ndarray) and raster.dtype.names:
                raster = structured_to_unstructured(raster)

            raster = raster.astype(np.float32)

            if raster.ndim == 2:
                # Plain 2D array: add a band axis -> (1, H, W)
                raster = raster[np.newaxis, :, :]
            elif raster.ndim == 3 and raster.shape[:2] == (patch_size, patch_size):
                # EE NPY data is (H, W, bands); rasterio expects (bands, H, W)
                raster = np.transpose(raster, (2, 0, 1))

            band_count, height, width = raster.shape

            if transform is not None and len(transform) >= 6:
                x_res = transform[0]
                y_res = abs(transform[4])
            else:
                x_res = (xmax - xmin) / width
                y_res = (ymax - ymin) / height

            affine_transform = rasterio.transform.from_origin(xmin, ymax, x_res, y_res)

            with rasterio.open(
                path,
                "w",
                driver="GTiff",
                height=height,
                width=width,
                count=band_count,
                dtype=raster.dtype,
                crs=crs,
                transform=affine_transform,
            ) as dst:
                dst.write(raster)

        print(f"Downloaded {split}/{chip_id}")

    except Exception as e:
        print(f"Error downloading {split}/{chip_id}: {e}")

Study area and training sample locations across Cambodia. Black squares indicate the locations of satellite image chips used to train and validate the deforestation detection model. Many samples are concentrated in forest landscapes experiencing some of the highest deforestation pressure in the country, including the Prey Lang forest complex in northern Cambodia, forest frontiers in the northeast, and the southwestern Cardamom Mountains. These regions have experienced sustained forest loss associated with agricultural expansion, logging, and land-use conversion, making them critical areas for monitoring deforestation dynamics.

aoi.jpg

This step retrieves the before, after, and label image chips from Google Earth Engine for both the training and validation datasets. For each sample, the corresponding images are extracted based on their geographic location, converted to arrays, and saved locally as GeoTIFF files with appropriate spatial reference information.

import os
import io
import time
import requests
import numpy as np
import rasterio
from concurrent.futures import ThreadPoolExecutor, as_completed
from numpy.lib.recfunctions import structured_to_unstructured
from tqdm import tqdm
import ee


# GLOBAL SESSION
session = requests.Session()


# DOWNLOAD FUNCTION
def fetch_with_retry(url, max_retries=5, backoff=2):
    for attempt in range(max_retries):
        try:
            response = session.get(url, timeout=60)

            if response.status_code == 429:
                raise Exception("Rate limited (429)")

            response.raise_for_status()
            return response

        except Exception:
            if attempt == max_retries - 1:
                raise
            time.sleep(backoff ** attempt)


# DOWNLOAD SINGLE IMAGE
def download_single(img, path, region_ll, proj_info, patch_size):
    url = img.getDownloadURL({
        "region": region_ll,
        "dimensions": [patch_size, patch_size],
        "format": "NPY",
    })

    response = fetch_with_retry(url)

    raster = np.load(io.BytesIO(response.content), allow_pickle=True)

    if isinstance(raster, np.ndarray) and raster.dtype.names:
        raster = structured_to_unstructured(raster)

    raster = raster.astype(np.float32)

    if raster.ndim == 2:
        raster = raster[np.newaxis, :, :]
    elif raster.ndim == 3:
        if raster.shape[0] == patch_size and raster.shape[1] == patch_size:
            raster = np.transpose(raster, (2, 0, 1))

    band_count, height, width = raster.shape

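    transform = proj_info.get("transform")
    crs = proj_info.get("crs", "EPSG:4326")

    if transform and len(transform) >= 6:
        # EE affine order: [xScale, xShear, xTranslation, yShear, yScale, yTranslation].
        # Using the asset's native origin and pixel size assumes the requested
        # region matches the chip asset's own extent and grid.
        x_res = transform[0]
        y_res = abs(transform[4])
        xmin = transform[2]
        ymax = transform[5]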
    else:
        coords = region_ll.bounds().getInfo()["coordinates"][0]
        xs = [c[0] for c in coords]
        ys = [c[1] for c in coords]

        xmin = min(xs)
        xmax = max(xs)
        ymin = min(ys)
        ymax = max(ys)

        x_res = (xmax - xmin) / width
        y_res = (ymax - ymin) / height

    affine = rasterio.transform.from_origin(xmin, ymax, x_res, y_res)

    with rasterio.open(
        path,
        "w",
        driver="GTiff",
        height=height,
        width=width,
        count=band_count,
        dtype=raster.dtype,
        crs=crs,
        transform=affine,
    ) as dst:
        dst.write(raster)


# MAIN CHIP FUNCTION
def download_chip_by_index(i, split="train", patch_size=512, scale=10):
    chip_id = f"{i:04d}"

    before_path = f"{base_dir}/{split}/before/{chip_id}.tif"
    after_path = f"{base_dir}/{split}/after/{chip_id}.tif"
    label_path = f"{base_dir}/{split}/label/{chip_id}.tif"

    if all(os.path.exists(p) for p in [before_path, after_path, label_path]):
        return "skipped", split, chip_id

    try:
        ee_base = "projects/mrv-cambodia/assets/trainingData"

        before = ee.Image(f"{ee_base}/{split}Before/{chip_id}")
        after = ee.Image(f"{ee_base}/{split}After/{chip_id}")
        label = ee.Image(f"{ee_base}/{split}Labels/{chip_id}")

        lonlat = before.geometry().centroid().coordinates().getInfo()
        point = ee.Geometry.Point(lonlat)
        region_ll = point.buffer(scale * patch_size / 2, 1).bounds(1)

        proj_info = before.projection().getInfo()

        image_tasks = [
            (before, before_path),
            (after, after_path),
            (label, label_path),
        ]

        with ThreadPoolExecutor(max_workers=3) as executor:
            futures = [
                executor.submit(
                    download_single,
                    img,
                    path,
                    region_ll,
                    proj_info,
                    patch_size,
                )
                for img, path in image_tasks
            ]

            for future in as_completed(futures):
                future.result()

        return "downloaded", split, chip_id

    except Exception as e:
        print(f"Failed to download {split}/{chip_id}: {e}")
        return "failed", split, chip_id


# BATCH DOWNLOAD
def download_batch(tasks, max_workers=5):
    start = time.time()

    downloaded = 0
    skipped = 0
    failed = 0

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(download_chip_by_index, i, split)
            for i, split in tasks
        ]

        for future in tqdm(as_completed(futures), total=len(futures), desc="Downloading chips"):
            status, split, chip_id = future.result()

            if status == "downloaded":
                downloaded += 1
            elif status == "skipped":
                skipped += 1
            elif status == "failed":
                failed += 1

    print(f"\nTOTAL TIME TAKEN: {time.time() - start:.2f} sec")
    print(f"Downloaded: {downloaded}")
    print(f"Skipped: {skipped}")
    print(f"Failed: {failed}")


# Create necessary directories
for split in ["train", "val"]:
    for sub in ["before", "after", "label"]:
        os.makedirs(f"{base_dir}/{split}/{sub}", exist_ok=True)


print("Starting data download...")

download_batch([(i, "train") for i in range(50)], max_workers=5)
download_batch([(i, "val") for i in range(20)], max_workers=5)

print("Data download complete!")
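Both versions of download_chip_by_index request a square region. They buffer the chip centroid by scale * patch_size / 2 metres and take the bounds. With the defaults (512 px at 10 m), that is a 5,120 m square. The sketch below approximates that box in longitude/latitude with a simple equirectangular conversion. Earth Engine computes the buffer geodesically, so treat the numbers as approximate. The function name, the 111,320 m-per-degree constant and the example centre are assumptions for illustration.

```python
import math

# Approximate metres per degree of latitude (and of longitude at the equator); an assumption
METERS_PER_DEGREE = 111_320.0


def approx_chip_bounds(lon: float, lat: float, patch_size: int = 512, scale: float = 10.0):
    """Approximate (west, south, east, north) bounds of a square chip centred on (lon, lat).

    Mirrors point.buffer(scale * patch_size / 2).bounds() with a flat-Earth
    (equirectangular) approximation instead of Earth Engine's geodesic buffer.
    """
    half_width_m = scale * patch_size / 2  # buffer radius in metres
    dlat = half_width_m / METERS_PER_DEGREE
    # Degrees of longitude shrink with cos(latitude)
    dlon = half_width_m / (METERS_PER_DEGREE * math.cos(math.radians(lat)))
    return (lon - dlon, lat - dlat, lon + dlon, lat + dlat)


patch_size, scale = 512, 10.0
print(f"Chip side: {scale * patch_size:.0f} m -> {patch_size} px at {scale:.0f} m")
print(approx_chip_bounds(104.9, 12.5))  # hypothetical chip centre in Cambodia
```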

8 Step 5: Data Inspection and Visualization

Before proceeding with model training, it’s crucial to inspect the data.
Visualizing the image chips will help verify that the downloaded data is correct and consistent.

8.1 What to Look For:

  • Correct alignment between the before and after images.
  • Accurate labeling of deforestation areas.
  • Proper data format (shape, data type).

8.1.1 Steps:

  1. Load random samples from the training set.
  2. Display the before, after, and label images side by side.
  3. Print the shape and type of the data to ensure correctness.
import random
import numpy as np
import matplotlib.pyplot as plt
from osgeo import gdal

def visualize_triplet(before_path, after_path, label_path):
    """
    Visualize the before, after, and label image chips side by side.
    """
    before = gdal.Open(before_path).ReadAsArray().astype(np.float32)
    after = gdal.Open(after_path).ReadAsArray().astype(np.float32)
    label = gdal.Open(label_path).ReadAsArray()

    # Extract RGB: indices 2, 1, 0 (bands 3, 2, 1 counting from 1),
    # assuming the chips store Blue, Green, Red as their first three bands
    before_rgb = before[[2, 1, 0], :, :]
    after_rgb = after[[2, 1, 0], :, :]

    # Move channels to the last dimension for display
    before_rgb = np.moveaxis(before_rgb, 0, -1)
    after_rgb = np.moveaxis(after_rgb, 0, -1)

    # Plot the images
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 3, 1)
    plt.title("Before Image (RGB)")
    plt.imshow(np.clip(before_rgb / 3000, 0, 1))  # Scale values to [0, 1] for display
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.title("After Image (RGB)")
    plt.imshow(np.clip(after_rgb / 3000, 0, 1))  # Scale values to [0, 1] for display
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.title("Deforestation Label")
    plt.imshow(label, cmap="gray")
    plt.axis("off")

    plt.tight_layout()
    plt.show()

# Select random sample from training data
sample_id = random.randint(0, 49)  # Assuming 50 training samples
chip_id = f"{sample_id:04d}"
before_path = f"{base_dir}/train/before/{chip_id}.tif"
after_path = f"{base_dir}/train/after/{chip_id}.tif"
label_path = f"{base_dir}/train/label/{chip_id}.tif"

print(f"Visualizing sample {chip_id}...")
visualize_triplet(before_path, after_path, label_path)
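Visual inspection is easy to skip on larger datasets, so a programmatic check of the points listed under "What to Look For" can complement it. The sketch below is a hypothetical helper that is not part of the original notebook. It works on arrays read the same way as in visualize_triplet (gdal.Open(...).ReadAsArray()). It assumes labels are binary (0 = no deforestation, 1 = deforestation), which you should confirm for your own label assets.

```python
import numpy as np


def check_triplet(before: np.ndarray, after: np.ndarray, label: np.ndarray, patch_size: int = 512) -> list:
    """Return a list of problems found in one before/after/label chip triplet.

    before/after: (bands, H, W) arrays; label: (H, W) or (1, H, W).
    An empty list means no problems were detected.
    """
    problems = []
    if label.ndim == 3 and label.shape[0] == 1:
        label = label[0]  # drop a singleton band axis

    for name, arr in (("before", before), ("after", after)):
        if arr.ndim != 3:
            problems.append(f"{name}: expected (bands, H, W), got {arr.shape}")
        elif arr.shape[1:] != (patch_size, patch_size):
            problems.append(f"{name}: spatial size {arr.shape[1:]} != ({patch_size}, {patch_size})")
        if np.isnan(arr.astype(np.float32)).any():
            problems.append(f"{name}: contains NaN values")

    if before.shape != after.shape:
        problems.append(f"before {before.shape} and after {after.shape} differ")
    if label.shape != (patch_size, patch_size):
        problems.append(f"label: shape {label.shape} != ({patch_size}, {patch_size})")

    values = set(np.unique(label).tolist())
    if not values <= {0, 1}:
        problems.append(f"label: unexpected values {sorted(values)}")
    return problems


chip = np.zeros((4, 8, 8), dtype=np.float32)
mask = np.zeros((8, 8))
mask[2:4, 2:4] = 1
print(check_triplet(chip, chip.copy(), mask, patch_size=8))
```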

9 Step 6: Cloning the Clay Model Repository

The Clay model is developed as an open-source project. Readers interested in the model architecture, pretrained checkpoints, and ongoing development can refer to the public repository and documentation mentioned at the top of this chapter.

To train the Clay Deforestation Segmentation Model, we need to clone the model repository from GitHub.
This repository contains the model architecture and necessary utility functions.

9.1 Why Clone the Repository?

  • Access to the Clay model definition.
  • Utility scripts for encoding and training.
  • Up-to-date model versions directly from the source.

9.1.1 Steps:

  1. Clone the repository from GitHub.
  2. Add the cloned folder to the Python path.
  3. Verify that the necessary modules are accessible.
# Clone the Clay model repository
!git clone --branch v1.0 https://github.com/Clay-foundation/model.git

# Add the cloned model directory to the Python path
import sys
sys.path.append("model")

# Verify that the Encoder class is accessible
try:
    from src.model import Encoder
    print("✅ Clay model successfully imported!")
except ImportError:
    print("❌ Failed to import the Clay model. Check the path and repository structure.")

10 Step 7: Downloading the Clay Model from Hugging Face

In this step, we will download the pre-trained Clay model checkpoint from the Hugging Face Model Hub.
This model serves as the foundation for our landcover change detection task.

💡 Why Use the Hugging Face Model Hub?

  • Centralized repository for sharing and accessing pre-trained models.
  • Reliable hosting and easy integration into machine learning workflows.
  • Local caching of downloaded files, with the option to pin a specific revision (as the revision argument in the code below does).

🔧 Installation:
We will install the huggingface_hub library to facilitate model download.
The model checkpoint will be saved locally for use in training and inference.

✅ Expected Outcome:
The Clay model checkpoint will be successfully downloaded and stored in the specified location.

# Install huggingface_hub for downloading models
!pip install huggingface_hub

from huggingface_hub import hf_hub_download

# Download the Clay model checkpoint from Hugging Face
ckpt_path = hf_hub_download(
    repo_id="made-with-clay/Clay",
    filename="clay-v1-base.ckpt",
    revision="1bb95c0d09239ebcffd76dd8939dd745a1d95dfe"
)


print(f"✅ Model downloaded and saved at: {ckpt_path}")

Step 8: Training Configuration

In this step, we will set up the training pipeline for the Clay Deforestation Segmentation Model.
This includes configuring the model, defining training parameters, and initializing the training loop.

10.1 Training Setup:

  1. Install the required libraries for training.
  2. Define hyperparameters (e.g., learning rate, batch size).
  3. Load the Clay model and set up the PyTorch Lightning Trainer.
  4. Configure checkpoints and logging for monitoring training progress.

10.1.1 Libraries Needed for Training:

  • torch: Deep learning framework.
  • torchvision: Data transformations and model utilities.
  • lightning: Simplifies training and model management.
  • segmentation_models_pytorch: Advanced loss functions and metrics.
!pip install torch torchvision lightning segmentation-models-pytorch

11 Step 9: Importing Libraries and Setting Up the Environment

In this step, we will import the necessary libraries to build, train, and evaluate the Clay-Based Landcover Change Detection model.

11.0.1 💡 Why These Libraries?

  • torch: Core deep learning library for model building and training.
  • torch.nn: Contains modules and functions for defining neural network architectures.
  • torch.nn.functional: Functional forms of activations, loss functions, and other operations.
  • torch.utils.data: Utilities for data loading and batching.
  • torchvision.transforms: Transformations for image preprocessing and augmentation.
  • einops: Efficient manipulation and rearrangement of multi-dimensional arrays.
  • lightning: Simplifies model training with clean and modular code.
  • torchmetrics: Collection of metrics for evaluating model performance (IoU, F1, Precision, Recall).
  • segmentation_models_pytorch: Advanced loss functions and segmentation metrics.

11.0.2 ⚙️ Device Check:

  • To utilize GPU acceleration, we check for the availability of CUDA.
  • If a GPU is detected, the model will run on CUDA; otherwise, it will use the CPU.

11.0.3 Expected Outcome:

  • Successful import of libraries.
  • Confirmation of the device being used (CUDA or CPU).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from einops import rearrange, repeat
import math
import os

import lightning as L
from torchmetrics.classification import BinaryJaccardIndex, BinaryF1Score, BinaryPrecision, BinaryRecall
import segmentation_models_pytorch as smp

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

12 Step 10: Building the Clay-Based Segmentor Model

In this step, we will build the Segmentor model using the Clay Foundation Model.

The model architecture is designed to leverage the power of the Vision Transformer (ViT) for spatial feature extraction, combined with a cross-attention mechanism to capture temporal changes between two images. Vision Transformers are particularly well suited to segmentation because they model long-range spatial dependencies across image patches, allowing the network to incorporate global context when identifying boundaries and fragmented patterns in land-cover change. Prior work has demonstrated strong performance of transformer-based architectures for semantic segmentation tasks (Strudel et al., 2021; Dosovitskiy et al., 2020).
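A ViT first cuts the image into non-overlapping patches and treats each flattened patch as a token. Attention between these tokens is what gives the model its global context. The NumPy sketch below shows the tokenization step for a (bands, H, W) chip. The function name patchify and the 8-pixel patch size in the usage line are illustrative assumptions. Clay then maps each token to an embedding with learned layers, which are omitted here.

```python
import numpy as np


def patchify(image: np.ndarray, patch: int) -> np.ndarray:
    """Split a (bands, H, W) image into non-overlapping patch tokens.

    Returns an array of shape (num_patches, bands * patch * patch),
    with patches ordered row by row across the image.
    """
    bands, h, w = image.shape
    if h % patch or w % patch:
        raise ValueError("H and W must be divisible by the patch size")
    gh, gw = h // patch, w // patch
    x = image.reshape(bands, gh, patch, gw, patch)
    x = x.transpose(1, 3, 0, 2, 4)  # (grid_row, grid_col, bands, patch, patch)
    return x.reshape(gh * gw, bands * patch * patch)


chip = np.zeros((10, 512, 512), dtype=np.float32)  # hypothetical 10-band chip
tokens = patchify(chip, 8)
print(tokens.shape)  # (4096, 640): a 64x64 grid of tokens, each 10 * 8 * 8 values
```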

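Both images are encoded by the same encoder. At each selected depth, the “before” and “after” token sequences are concatenated along the embedding dimension, projected back to 768 dimensions, and passed through an attention block (FlashCrossAttention), so that the temporal embeddings interact while spatial context is preserved. The block takes separate query and key/value inputs, following the encoder-decoder attention formulation introduced in the original Transformer (Vaswani et al., 2017), and the overall design is aligned with later transformer-based remote-sensing change detection frameworks (Chen et al., 2022). In the Segmentor below, the same fused sequence is passed as both query and key/value, so in this configuration the block effectively performs self-attention over the fused before/after tokens.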
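The difference between true cross-attention and the configuration used in this Segmentor can be illustrated with a minimal NumPy sketch (single head, randomly initialised toy projections, no layer normalization, positional encodings, or feed-forward network). The `attention` and `softmax` helpers, the token counts, and the projection matrices are illustrative assumptions, not part of Clay or of the code in this chapter. In true cross-attention the queries come from one date (here the “after” tokens) and the keys and values from the other (“before”); the Segmentor code in this step instead concatenates and projects the two token sequences and passes the fused sequence as both query and key/value.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(queries, keys_values, w_q, w_k, w_v):
    """Single-head scaled dot-product attention.

    queries:     [Lq, D] tokens that produce the queries
    keys_values: [Lk, D] tokens that produce the keys and values
    Returns (output [Lq, D], attention weights [Lq, Lk]).
    """
    d = queries.shape[-1]
    q, k, v = queries @ w_q, keys_values @ w_k, keys_values @ w_v
    weights = softmax(q @ k.T / np.sqrt(d))  # each row sums to 1
    return weights @ v, weights


rng = np.random.default_rng(0)
num_tokens, dim = 16, 32  # toy sizes, e.g. a 4 x 4 grid of patch tokens
before = rng.standard_normal((num_tokens, dim))
after = rng.standard_normal((num_tokens, dim))
w_q, w_k, w_v = (rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(3))

# True cross-attention: every "after" token attends over all "before" tokens
cross_out, cross_w = attention(after, before, w_q, w_k, w_v)

# Segmentor-style fusion: concatenate, project back to dim, attend to itself
w_reduce = rng.standard_normal((2 * dim, dim)) / np.sqrt(2 * dim)
fused = np.concatenate([before, after], axis=-1) @ w_reduce  # [16, 32]
self_out, self_w = attention(fused, fused, w_q, w_k, w_v)

print(cross_out.shape, self_out.shape)
```

Both variants return one output token per query token, so the spatial token grid is preserved either way; the difference is whether the attention weights relate tokens of one date to tokens of the other, or fused tokens to each other.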

12.0.1 💡 Model Architecture Overview:

  1. SegmentEncoder:
    • A custom encoder built on the Clay Foundation Model's Vision Transformer (ViT) encoder, initialized from the Clay checkpoint; the loaded weights are frozen.
    • Returns token features from four intermediate transformer layers (indices 2, 5, 8, and 11) rather than only the last layer, giving multi-depth features at the same patch resolution.
    • Generates fixed sinusoidal positional encodings for the patch tokens.
  2. Cross-Attention Mechanism:
    • At each depth, fuses the before and after tokens by concatenation, a linear projection back to 768 dimensions, and FlashCrossAttention, an attention block built on PyTorch's scaled_dot_product_attention.
    • The fused features from the four depths are then summed.
  3. Upsampling Layers:
    • A lightweight convolutional decoder (without skip connections) upsamples the fused token grid.
    • Three transposed convolutions, each giving 2× upsampling (8× in total, matching the patch size of 8), interleaved with two 3×3 convolutions, GELU activations, and one batch normalization layer.
  4. Final Segmentation Head:
    • A 1×1 convolution outputs one logit per pixel for the deforestation class.
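The shape bookkeeping implied by this overview can be checked with a small pure-Python sketch. `trace_shapes` is a hypothetical helper, and the 256 × 256 chip size is an assumption for illustration; the patch size (8), embedding size (768), and decoder settings mirror the Segmentor code in this step. It shows how three transposed convolutions with kernel_size=4, stride=2, and padding=1, each doubling the spatial size, undo the 8× downsampling of the patch embedding.

```python
def trace_shapes(height, width, patch_size=8, dim=768):
    """Trace tensor shapes through the Segmentor for one image pair (batch of 1).

    Mirrors the code in this step: patch_size=8, dim=768, and three
    ConvTranspose2d(kernel_size=4, stride=2, padding=1) layers whose channel
    counts are 768 // 4, 768 // 8 and 768 // 16, followed by a 1x1 conv head.
    """
    assert height % patch_size == 0 and width % patch_size == 0
    h, w = height // patch_size, width // patch_size
    shapes = {
        "tokens": (1, h * w, dim),      # encoder output at each depth
        "token_grid": (1, dim, h, w),   # after rearrange "B (H W) D -> B D H W"
    }
    for i, divisor in enumerate([4, 8, 16], start=1):
        # ConvTranspose2d output size: (in - 1) * stride - 2 * padding + kernel
        h = (h - 1) * 2 - 2 * 1 + 4
        w = (w - 1) * 2 - 2 * 1 + 4
        shapes[f"upsample_{i}"] = (1, dim // divisor, h, w)
    shapes["logits"] = (1, 1, h, w)     # 1x1 conv to num_classes=1
    return shapes


for name, shape in trace_shapes(256, 256).items():
    print(f"{name:>12}: {shape}")
```

For a 256 × 256 chip this gives 1,024 tokens on a 32 × 32 grid, and the decoder restores a 256 × 256 logit map.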

12.0.2 ✅ Expected Outcome:

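  • The model outputs a per-pixel logit map; after a sigmoid and a 0.5 threshold, this gives a binary mask of land-cover change (deforestation) between the two satellite images.
  • Combining features from several encoder depths and fusing the before and after tokens with attention is intended to improve temporal consistency and spatial accuracy.
  • Freezing the pretrained Clay encoder weights means only the fusion, decoder, and head layers are trained, which keeps training comparatively lightweight.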
import re  # used by SegmentEncoder.load_from_ckpt to rename checkpoint keys
import sys
sys.path.append("/content/model/src")

# SegmentEncoder: Extract multi-scale ViT features from datacubes
from torch.nn.functional import scaled_dot_product_attention

from src.model import Encoder

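class SegmentEncoder(Encoder):
    """
    Encoder for segmentation tasks, built on the Clay Encoder. Instead of
    returning only the final layer, it returns token features from several
    intermediate transformer layers, all at the patch resolution.

    Attributes:
        feature_layers (list): Indices of the transformer layers whose
            outputs are returned as feature maps.

    Args:
        ckpt_path (str): Path to the Clay checkpoint file.
    """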

    def __init__(  # noqa: PLR0913
        self,
        mask_ratio,
        patch_size,
        shuffle,
        dim,
        depth,
        heads,
        dim_head,
        mlp_ratio,
        ckpt_path=None,
    ):
        super().__init__(
            mask_ratio,
            patch_size,
            shuffle,
            dim,
            depth,
            heads,
            dim_head,
            mlp_ratio,
        )

        # Set device
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        # Load model from checkpoint if provided
        self.load_from_ckpt(ckpt_path)

        # Indices of the transformer layers (out of 12) whose outputs are
        # returned as feature maps
        self.feature_layers = [2, 5, 8, 11]

    def load_from_ckpt(self, ckpt_path):
        """
        Load the model's state from a checkpoint file.

        Args:
            ckpt_path (str): The path to the checkpoint file.
        """
        if ckpt_path:
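            # Load checkpoint. weights_only=False may be needed on PyTorch >= 2.6,
            # where torch.load defaults to weights_only=True; only use it with
            # trusted checkpoint files.
            ckpt = torch.load(ckpt_path, map_location=self.device, weights_only=False)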
            state_dict = ckpt.get("state_dict")

            # Prepare new state dict with the desired subset and naming
            new_state_dict = {
                re.sub(r"^model\.encoder\.", "", name): param
                for name, param in state_dict.items()
                if name.startswith("model.encoder")
            }

            # Load the modified state dict into the model
            model_state_dict = self.state_dict()
            for name, param in new_state_dict.items():
                if (
                    name in model_state_dict
                    and param.size() == model_state_dict[name].size()
                ):
                    model_state_dict[name].copy_(param)
                else:
                    print(f"No matching parameter for {name} with size {param.size()}")

            # Freeze the loaded parameters
            for name, param in self.named_parameters():
                if name in new_state_dict:
                    param.requires_grad = False


    def generate_positional_encoding(self,shape):
        """
        Generate fixed sinusoidal positional encodings.

        Args:
            shape (tuple): Expected shape (B, L, D) where:
                          - B = Batch size
                          - L = Sequence length (number of patches)
                          - D = Embedding dimension

        Returns:
            torch.Tensor: Positional encoding of shape (B, L, D).
        """
        B, L, D = shape
        pos = torch.arange(L, device=self.device).unsqueeze(1)  # (L, 1)
        div_term = torch.exp(torch.arange(0, D, 2, device=self.device) * (-math.log(10000.0) / D))

        pe = torch.zeros(L, D, device=self.device)
        pe[:, 0::2] = torch.sin(pos * div_term)  # Apply sin to even indices
        pe[:, 1::2] = torch.cos(pos * div_term)  # Apply cos to odd indices

        pe = pe.unsqueeze(0).expand(B, -1, -1)  # Expand to match batch size
        return pe  # Shape: (B, L, D)



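    def forward(self, datacube):
        """
        Forward pass of the SegmentEncoder.

        Args:
            datacube (dict): A dictionary containing the input datacube and
                meta information like time, latlon, gsd & wavelengths.

        Returns:
            tuple: (feature_maps, positional_encodings), where feature_maps is
                a list of [B, L, D] token tensors (one per entry in
                self.feature_layers, class token excluded) and
                positional_encodings is a [B, L, D] tensor.
        """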
        cube, time, latlon, gsd, waves = (
            datacube["pixels"],  # [B C H W]
            datacube["time"],  # [B 4]
            datacube["latlon"],  # [B 4]
            datacube["gsd"],  # 1
            datacube["waves"],  # [N]
        )

        B, C, H, W = cube.shape

        # Patchify and create embeddings per patch
        patches, waves_encoded = self.to_patch_embed(cube, waves)  # [B L D]
        patches = self.add_encodings(patches, time, latlon, gsd)  # [B L D]

        # Add class tokens
        cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B)  # [B 1 D]
        patches = torch.cat((cls_tokens, patches), dim=1)  # [B (1 + L) D]

        # Extract features from multiple depths
        feature_maps = []

        # Iterate over transformer layers to capture intermediate outputs
        for idx, (attn, ff) in enumerate(self.transformer.layers):
            patches = attn(patches) + patches
            patches = ff(patches) + patches
            if idx in self.feature_layers:
                feature_maps.append(patches[:, 1:, :])  # Exclude class token


        positional_encodings = self.generate_positional_encoding(patches[:, 1:, :].shape)
        return feature_maps, positional_encodings  # Return multiple feature maps


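class FlashCrossAttention(nn.Module):
    """
    Attention block with separate query and key/value inputs, followed by a
    feed-forward network, each with a residual connection. It uses PyTorch's
    scaled_dot_product_attention, which can dispatch to fused kernels such as
    FlashAttention when the hardware and data types allow it.
    """

    def __init__(self, d_model, num_heads, dropout_p=0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.dropout_p = dropout_p

        # Query, key and value projections (d_model -> d_model)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)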

        self.norm_before = nn.LayerNorm(d_model)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.ReLU(),
            nn.Linear(d_model * 2, d_model),
        )


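    def forward(self, fused_enc, fused_enc_k_v, fused_pos, fused_pos_k_v):
        # fused_enc provides the queries and fused_enc_k_v the keys/values.
        # Positional encodings are added to the keys only; fused_pos is
        # accepted for symmetry but not used.
        fused_enc_norm = self.norm_before(fused_enc)
        fused_enc_k_v_norm = self.norm_before(fused_enc_k_v)

        Q = self.q_proj(fused_enc_norm)
        K = self.k_proj(fused_enc_k_v_norm + fused_pos_k_v)
        V = self.v_proj(fused_enc_k_v_norm)

        # Split heads: [B, L, D] -> [B, heads, L, head_dim]
        B, q_len, _ = Q.shape
        kv_len = K.shape[1]
        Q = Q.view(B, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention dropout is applied only in training mode
        attn_output = scaled_dot_product_attention(
            Q, K, V, dropout_p=self.dropout_p if self.training else 0.0
        )

        # Merge heads: [B, heads, L, head_dim] -> [B, L, D]
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, q_len, self.d_model)

        # Residual connections around the attention and the feed-forward network
        fused_features = attn_output + fused_enc
        fused_features = self.ffn(fused_features) + fused_features

        return fused_features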

class Segmentor(nn.Module):
    def __init__(self, num_classes, ckpt_path):
        super().__init__()

        self.encoder = SegmentEncoder(
            mask_ratio=0.0,
            patch_size=8,
            shuffle=False,
            dim=768,
            depth=12,
            heads=12,
            dim_head=64,
            mlp_ratio=4.0,
            ckpt_path=ckpt_path,
        )

        # One attention block, shared across all four feature depths
        self.cross_attention_layer = FlashCrossAttention(d_model=768, num_heads=8)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(768, 768 // 4, kernel_size=4, stride=2, padding=1),  # 2x upsampling
            nn.BatchNorm2d(768 // 4),
            nn.GELU(),
            nn.Conv2d(768 // 4, 768 // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(768 // 4, 768 // 8, kernel_size=4, stride=2, padding=1),  # 2x upsampling
            nn.GELU(),
            nn.Conv2d(768 // 8, 768 // 8, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(768 // 8, 768 // 16, kernel_size=4, stride=2, padding=1),  # 2x upsampling (8x in total, matching patch_size=8)
            nn.GELU(),
        )

        # Per-depth projections (4 depths): concatenated before/after tokens and
        # positional encodings (2 x 768) are projected back to 768 dimensions
        self.reduce_enc = nn.ModuleList([nn.Linear(768 * 2, 768) for _ in range(4)])
        self.reduce_pos = nn.ModuleList([nn.Linear(768 * 2, 768) for _ in range(4)])

        # Final segmentation head
        self.conv_out = nn.Conv2d(768 // 16, num_classes, kernel_size=1)

    def forward(self, before_datacube, after_datacube):
        """
        Forward pass of the Segmentor.
        """

        before_feats, before_pos = self.encoder(before_datacube)
        after_feats, after_pos = self.encoder(after_datacube)

        # Fuse before/after features at each depth: concatenate along the
        # embedding dimension, project back to 768, then apply the attention
        # block with the same fused sequence as query and key/value
        fused_features = []
        for i in range(len(before_feats)):
            reduced_enc = self.reduce_enc[i](torch.cat([before_feats[i], after_feats[i]], dim=-1))

            reduced_pos = self.reduce_pos[i](torch.cat([before_pos, after_pos], dim=-1))

            fused_features.append(self.cross_attention_layer(reduced_enc, reduced_enc, reduced_pos, reduced_pos))

        # Merge features from multiple levels
        fused_features = torch.stack(fused_features, dim=-1).sum(dim=-1)  # Simple summation

        B, C, H_in, W_in = before_datacube['pixels'].shape

        H_patches = H_in // self.encoder.patch_size
        W_patches = W_in // self.encoder.patch_size

        features = rearrange(fused_features, "B (H W) D -> B D H W", H=H_patches, W=W_patches)

        # Upscale Features
        upscaled_features = self.output_upscaling(features)

        # Generate Final Segmentation Mask
        mask_pred = self.conv_out(upscaled_features)


        return mask_pred
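`generate_positional_encoding` uses the standard sinusoidal scheme from the original Transformer. The NumPy sketch below re-implements the same arithmetic so its properties are easy to inspect; `sinusoidal_encoding` is a hypothetical name, and 1,024 tokens corresponds to an assumed 256 × 256 chip with patch size 8. Note that positions are the flattened token index (row by row over the patch grid), not separate row and column coordinates.

```python
import math

import numpy as np


def sinusoidal_encoding(length, dim):
    """NumPy version of the arithmetic in generate_positional_encoding.

    Even channels get sin(pos * f_i) and odd channels cos(pos * f_i), with
    f_i = 10000 ** (-2i / dim). Returns an array of shape [length, dim].
    """
    pos = np.arange(length)[:, None]                                       # [L, 1]
    div_term = np.exp(np.arange(0, dim, 2) * (-math.log(10000.0) / dim))   # [D / 2]
    pe = np.zeros((length, dim))
    pe[:, 0::2] = np.sin(pos * div_term)  # even channels
    pe[:, 1::2] = np.cos(pos * div_term)  # odd channels
    return pe


# 32 x 32 patch tokens, flattened to a sequence of 1024 positions
pe = sinusoidal_encoding(length=1024, dim=768)
print(pe.shape, pe.min(), pe.max())
```

At position 0, every even channel is sin(0) = 0 and every odd channel is cos(0) = 1; all values lie in [-1, 1].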

13 Step 11: Building the changeSegmentor Lightning Module

In this step, we define the changeSegmentor class, which inherits from LightningModule.
This module integrates the segmentation model with PyTorch Lightning, allowing for structured training and evaluation.

13.0.1 💡 Why Use LightningModule?

  • Simplifies the training loop and optimizer configuration.
  • Enables efficient GPU training and logging.
  • Provides flexibility to define training and validation steps separately.

13.0.2 ⚙️ Key Components of changeSegmentor:

  1. Initialization (__init__):
    • Loads the Clay Segmentor model for deforestation detection.
    • Initializes loss functions and evaluation metrics.
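    • Stores the optimizer hyperparameters (learning rate, weight decay, AdamW betas); the AdamW optimizer and the cosine annealing scheduler with warm restarts are created in configure_optimizers.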
    • Defines spectral bands (wavelengths) and ground sample distance (GSD) for Sentinel-2 data.
  2. Forward Method:
    • Takes before and after datacubes as inputs.
    • Adds wavelength and GSD metadata to the datacubes.
    • Returns the model’s prediction for landcover change.
  3. Loss Functions:
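    • Uses a weighted combination of:
      • BCE Loss: Binary cross-entropy with a positive-class weight of 20 (pos_weight) to counter the rarity of deforestation pixels.
      • Dice Loss: Optimizes for overlap between prediction and ground truth.
      • Tversky Loss: Weights false positives (alpha = 0.6) slightly more than false negatives (beta = 0.4).
    • Loss formula: 0.1 * BCE + 0.3 * Dice + 0.6 * Tversky (a focal loss is also instantiated but is not used in this combination).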
  4. Metrics:
    • Intersection over Union (IoU): Measures overlap between predicted and actual masks.
    • F1 Score: Balances precision and recall.
    • Precision and Recall: Precision is the fraction of predicted deforestation pixels that are correct; recall is the fraction of true deforestation pixels that are detected.
  5. Shared Step Method:
    • Used for both training and validation.
    • Computes the loss and evaluation metrics for a given batch.
    • Logs the metrics for real-time monitoring.
  6. Training and Validation Steps:
    • Uses the shared step method to perform forward passes during both training and validation.
    • Allows for consistent loss and metric calculation across different phases.
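The combined loss computed in shared_step in the code below (0.1 × BCE + 0.3 × Dice + 0.6 × Tversky) can be reproduced with a simplified NumPy sketch. It re-implements the weighted BCE formula of nn.BCEWithLogitsLoss (pos_weight = 20) and soft Dice/Tversky scores computed from probabilities (Tversky alpha = 0.6 on false positives, beta = 0.4 on false negatives). It is not segmentation_models_pytorch's exact implementation, which differs in details such as smoothing and the handling of empty masks; the helper names and toy data are assumptions.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def weighted_bce_with_logits(logits, target, pos_weight=20.0):
    # Mean of -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
    log_p = -np.logaddexp(0.0, -logits)   # log(sigmoid(x)), numerically stable
    log_1mp = -np.logaddexp(0.0, logits)  # log(1 - sigmoid(x))
    return float(np.mean(-(pos_weight * target * log_p + (1 - target) * log_1mp)))


def soft_tversky_loss(prob, target, alpha, beta, eps=1e-7):
    # Soft counts: probabilities stand in for hard 0/1 predictions
    tp = np.sum(prob * target)
    fp = np.sum(prob * (1 - target))
    fn = np.sum((1 - prob) * target)
    return float(1 - (tp + eps) / (tp + alpha * fp + beta * fn + eps))


def combined_loss(logits, target):
    prob = sigmoid(logits)
    bce = weighted_bce_with_logits(logits, target)
    # Dice is the special case of Tversky with alpha = beta = 0.5
    dice = soft_tversky_loss(prob, target, alpha=0.5, beta=0.5)
    tversky = soft_tversky_loss(prob, target, alpha=0.6, beta=0.4)
    return 0.1 * bce + 0.3 * dice + 0.6 * tversky


rng = np.random.default_rng(0)
target = (rng.random((64, 64)) < 0.05).astype(float)  # ~5% deforested pixels
good_logits = np.where(target == 1, 4.0, -4.0)        # confident, correct
random_logits = rng.standard_normal((64, 64))         # uninformative
print(combined_loss(good_logits, target), combined_loss(random_logits, target))
```

The Dice and Tversky terms depend on overlap ratios rather than per-pixel averages, so they stay informative even when deforestation pixels are a small fraction of each chip.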

13.0.3 ✅ Expected Outcome:

  • A training pipeline for land-cover change detection that computes the combined loss and metrics consistently for training and validation.
  • The model should learn to identify deforestation from paired satellite images; how well it does so is measured with the validation metrics.
  • Per-step logging of loss, IoU, F1, Precision, and Recall during training.
class changeSegmentor(L.LightningModule):
    def __init__(self, lr=1e-4, wd=1e-2, b1=0.9, b2=0.999):
        super().__init__()

        self.save_hyperparameters()

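        # ckpt_path (the pretrained Clay checkpoint path) must be defined before
        # this class is instantiated
        self.model = Segmentor(num_classes=1, ckpt_path=ckpt_path)

        # Loss functions (focal_loss is instantiated but not used in shared_step)
        self.dice_loss = smp.losses.DiceLoss(mode="binary")
        self.focal_loss = smp.losses.FocalLoss(mode="binary", alpha=0.25, gamma=2.0)
        # In smp's TverskyLoss, alpha weights false positives and beta false negatives
        self.tversky_loss = smp.losses.TverskyLoss(mode="binary", alpha=0.6, beta=0.4)
        # pos_weight=20 up-weights the rare positive (deforestation) pixels
        self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([20.0]))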

        # Metrics
        self.iou = BinaryJaccardIndex()
        self.f1 = BinaryF1Score()
        self.precision = BinaryPrecision()
        self.recall = BinaryRecall()

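        # Centre wavelengths (micrometres) of 10 Sentinel-2 bands (B2, B3, B4, B5,
        # B6, B7, B8, B8A, B11, B12), assumed to match the band order of the chips,
        # and the ground sample distance (metres)
        self.waves = torch.tensor([0.493, 0.56, 0.665, 0.704, 0.74, 0.783, 0.842, 0.865, 1.61, 2.19])
        self.gsd = torch.tensor(10.0)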

    def forward(self, before_datacube, after_datacube):

        before_datacube['waves'] = self.waves
        after_datacube['waves'] = self.waves
        before_datacube['gsd'] = self.gsd
        after_datacube['gsd'] = self.gsd

        return self.model(before_datacube, after_datacube)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.hparams.lr,
            weight_decay=self.hparams.wd,
            betas=(self.hparams.b1, self.hparams.b2),
        )
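        # Lightning steps this scheduler once per epoch by default; with T_0=100
        # and 20 training epochs, no warm restart occurs within this run
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(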
            optimizer,
            T_0=100,
            T_mult=1,
            eta_min=self.hparams.lr * 1e-2,
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def shared_step(self, batch, phase):
        before, after = batch["before"], batch["after"]
        label = batch["label"][:, 0, :, :]  # assume binary label in shape [B, 1, H, W]

        output = self(before, after)[:, 0, :, :]  # reduce channel dim to [B, H, W]

        # Compute losses
        bce = self.bce_loss(output, label)
        dice = self.dice_loss(output, label)
        tversky = self.tversky_loss(output, label)

        loss = 0.1 * bce + 0.3 * dice + 0.6 * tversky

        pred = (torch.sigmoid(output) > 0.5).float()

        # Handle edge case: empty masks
        if label.sum() == 0 and pred.sum() == 0:
            iou = torch.tensor(1.0, device=self.device)
        else:
            iou = self.iou(pred, label)

        f1 = self.f1(pred, label)
        precision = self.precision(pred, label)
        recall = self.recall(pred, label)

        self.log_dict({
            f"{phase}/loss": loss,
            f"{phase}/iou": iou,
            f"{phase}/f1": f1,
            f"{phase}/precision": precision,
            f"{phase}/recall": recall,
        }, prog_bar=True, sync_dist=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "val")
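To see what configure_optimizers does over a short run, the closed-form learning rate of CosineAnnealingWarmRestarts (with T_mult = 1) can be evaluated directly. This pure-Python sketch assumes Lightning's default of stepping the scheduler once per epoch and uses lr = 1e-4 and eta_min = lr × 1e-2 = 1e-6 from the code above; `cosine_warm_restart_lr` is a hypothetical helper.

```python
import math


def cosine_warm_restart_lr(epoch, base_lr=1e-4, eta_min=1e-6, t_0=100):
    """Learning rate of CosineAnnealingWarmRestarts (T_mult = 1) at an integer epoch.

    eta_t = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * T_cur / T_0)),
    where T_cur is the number of epochs since the last restart.
    """
    t_cur = epoch % t_0  # restarts every t_0 epochs when T_mult = 1
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_0))


for epoch in (0, 5, 10, 19):
    print(epoch, f"{cosine_warm_restart_lr(epoch):.3e}")
```

With T_0 = 100 and the 20 epochs used in Step 13, the learning rate only falls from 1e-4 to roughly 9.1e-5 by the last epoch, so the schedule acts as a gentle decay rather than a full cosine cycle.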

14 Step 12: Building the Sentinel2ChangeDataset Class

In this step, we will define a custom dataset class to handle the loading and preprocessing of satellite image pairs.
This class will be used to prepare input data for training the Clay Deforestation Segmentation Model.

14.0.1 💡 Why Use a Custom Dataset Class?

  • Efficiently loads and preprocesses before-and-after satellite images.
  • Provides a structured way to handle large datasets.
  • Integrates seamlessly with PyTorch’s DataLoader for batching and shuffling.
  • Allows for on-the-fly data normalization, reducing preprocessing time.

14.0.2 ⚙️ How It Works:

  1. Initialization:
    • Takes directories for before, after, and label images as input.
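    • Collects the .tif file paths in each directory into sorted lists; files are paired by sorted position, so the three directories must use matching filenames.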
    • Sets the mean and standard deviation for Sentinel-2 bands.
    • Defines a normalization transform for input consistency.
  2. Loading and Preprocessing:
    • Uses rasterio to read GeoTIFF images efficiently.
    • Normalizes the images using preset mean and standard deviation values.
    • Creates a datacube structure containing:
      • Pixel values
      • Placeholder time information
      • Placeholder geographic coordinates
  3. Data Retrieval:
    • Implements __getitem__ to return before, after, and label data as a dictionary.
    • Applies the normalization transform to the before and after images (labels are left unnormalized).
  4. Dataset Length:
    • Returns the total number of image pairs available.
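v2.Normalize applies a per-channel z-score, output[c] = (input[c] - mean[c]) / std[c]. The NumPy sketch below reproduces that arithmetic with the per-band statistics used in Sentinel2ChangeDataset; `normalize_chip` is a hypothetical helper for illustration, not part of torchvision.

```python
import numpy as np

# Per-band statistics copied from Sentinel2ChangeDataset (10 Sentinel-2 bands)
MEAN = np.array([1105., 1355., 1552., 1887., 2422., 2630., 2743., 2785., 2388., 1835.])
STD = np.array([1809., 1757., 1888., 1870., 1732., 1697., 1742., 1648., 1470., 1379.])


def normalize_chip(chip):
    """Z-score each band of a [C, H, W] chip: (x - mean[c]) / std[c]."""
    chip = chip.astype(np.float32)
    # [:, None, None] broadcasts the per-band values over height and width
    return (chip - MEAN[:, None, None]) / STD[:, None, None]


# A chip whose every pixel equals the band mean normalizes to all zeros
example = np.broadcast_to(MEAN[:, None, None], (10, 4, 4)).copy()
print(np.abs(normalize_chip(example)).max())
```

A pixel equal to the band mean maps to 0, and one standard deviation above the mean maps to 1, so all bands end up on a comparable scale regardless of their raw reflectance range.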

14.0.3 ✅ Expected Outcome:

  • Efficient data loading during training and validation.
  • Preprocessed and normalized images are ready for model ingestion.
  • The model will receive consistent input batches, improving training stability.
from torch.utils.data import Dataset, DataLoader
import rasterio
import numpy as np
from pathlib import Path

class Sentinel2ChangeDataset(Dataset):
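    def __init__(self, before_dir, after_dir, label_dir, transform=None):
        self.before_dir = Path(before_dir)
        self.after_dir = Path(after_dir)
        self.label_dir = Path(label_dir)

        # Files are paired by sorted position, so the three folders must
        # contain matching filenames (e.g. a shared chip ID)
        self.before_files = sorted(self.before_dir.glob("*.tif"))
        self.after_files = sorted(self.after_dir.glob("*.tif"))
        self.label_files = sorted(self.label_dir.glob("*.tif"))

        # Per-band mean and standard deviation for the 10 Sentinel-2 bands
        mean_values = np.array([1105., 1355., 1552., 1887., 2422., 2630., 2743., 2785., 2388., 1835.])
        std_values = np.array([1809., 1757., 1888., 1870., 1732., 1697., 1742., 1648., 1470., 1379.])

        # Use the provided transform if given, otherwise per-band normalization
        self.transform = transform if transform is not None else self.create_transforms(
            mean=mean_values.tolist(),
            std=std_values.tolist(),
        )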


    def create_transforms(self, mean, std):
        """
        Create normalization transforms.

        Args:
            mean (list): Mean values for normalization.
            std (list): Standard deviation values for normalization.

        Returns:
            torchvision.transforms.Compose: A composition of transforms.
        """
        return v2.Compose([
            v2.Normalize(mean=mean, std=std),
        ])

    def load_tiff(self, path):
        with rasterio.open(path) as src:
            arr = src.read()
        return arr

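    def __getitem__(self, idx):
        # Read [C, H, W] arrays and cast to float32 (v2.Normalize requires float tensors)
        before = self.load_tiff(self.before_files[idx]).astype(np.float32)
        after = self.load_tiff(self.after_files[idx]).astype(np.float32)
        label = self.load_tiff(self.label_files[idx]).astype(np.float32)

        before = self.transform(torch.from_numpy(before))
        after = self.transform(torch.from_numpy(after))
        label = torch.from_numpy(label)

        # Time and lat/lon are zero placeholders (no metadata is provided)
        datacube_before = {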
            "pixels": before,
            "time": torch.zeros(4),
            "latlon": torch.zeros(4),
        }

        datacube_after = {
            "pixels": after,
            "time": torch.zeros(4),
            "latlon": torch.zeros(4),
        }

        return {
            "before": datacube_before,
            "after": datacube_after,
            "label": label,
        }

    def __len__(self):
        return len(self.before_files)
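Because Sentinel2ChangeDataset pairs files purely by their position in three independently sorted lists, a missing or extra file in one folder silently misaligns every later sample. A minimal guard is sketched below, assuming the before, after, and label chips share filenames (adapt the comparison if your export uses a different naming scheme); `check_pairing` is a hypothetical helper.

```python
from pathlib import Path


def check_pairing(before_files, after_files, label_files):
    """Raise ValueError unless the three sorted file lists pair up by filename.

    Hypothetical helper: assumes matching chips share the same filename in the
    before/, after/ and label/ folders.
    """
    names = [[Path(f).name for f in files] for files in (before_files, after_files, label_files)]
    if not (len(names[0]) == len(names[1]) == len(names[2])):
        raise ValueError(f"Unequal file counts: {[len(n) for n in names]}")
    for i, (b, a, lab) in enumerate(zip(*names)):
        if not (b == a == lab):
            raise ValueError(f"Mismatch at index {i}: {b}, {a}, {lab}")


# Usage after building a dataset, for example:
# check_pairing(train_dataset.before_files, train_dataset.after_files, train_dataset.label_files)
```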

15 Step 13: Training the Model with PyTorch Lightning

In this step, we will train the Clay Deforestation Segmentation Model using PyTorch Lightning.
Lightning simplifies the training loop, handles device management, and provides a structured way to monitor model performance.

15.0.1 💡 Why Use PyTorch Lightning?

  • Streamlines the training and validation process.
  • Supports multi-GPU training with minimal configuration.
  • Integrated logging and checkpointing.
  • Enables efficient monitoring through TensorBoard.

15.0.2 ⚙️ Training Process:

  1. Dataset Initialization:
    • Creates training and validation datasets using the Sentinel2ChangeDataset class.
    • Sets the paths for before, after, and label images.
  2. Data Loaders:
    • Uses DataLoader to batch and shuffle data.
    • Enables efficient data fetching during training and validation.
    • Configured with a batch size of 4 for training and 1 for validation.
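  3. Model Initialization:
    • Instantiates the changeSegmentor module, which loads the pretrained Clay encoder weights from the checkpoint (ckpt_path).
    • Enables gradient checkpointing if the wrapped model provides a gradient_checkpointing_enable method (the Segmentor defined here does not, so this is a no-op).
  4. Logger Configuration:
    • Uses the TensorBoardLogger to log training metrics.
    • Logs are stored in the lightning_logs/ directory.
  5. Trainer Setup:
    • Configures the Trainer with the following settings:
      • Max Epochs: 20
      • Accelerator: a single GPU (the code sets accelerator="gpu", so a CUDA GPU is required)
      • Precision: 16-bit mixed precision ("16-mixed")
      • Logger: TensorBoard for real-time monitoring
      • Progress Bar: Enabled for cleaner output
      • Logging Frequency: Logs after every step
  6. Training the Model:
    • Uses trainer.fit() to start the training process.
    • Passes both the training and validation data loaders, so validation metrics are computed at the end of each training epoch.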
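The Trainer call in this step hard-codes accelerator="gpu" and precision="16-mixed", so it will fail on a CPU-only runtime. A minimal sketch of a fallback is shown below; `trainer_kwargs` is a hypothetical helper, and falling back to full 32-bit precision on CPU is a conservative choice. Training this ViT-based model on CPU is expected to be much slower, so the fallback is mainly useful for debugging.

```python
def trainer_kwargs(has_cuda):
    """Hypothetical helper returning Trainer settings that also work without a GPU.

    On a CUDA GPU it keeps the settings used in this step; on CPU it falls back
    to full 32-bit precision ("32-true").
    """
    if has_cuda:
        return {"accelerator": "gpu", "devices": 1, "precision": "16-mixed"}
    return {"accelerator": "cpu", "devices": 1, "precision": "32-true"}


# Usage with Lightning (not executed here):
# import torch
# from lightning.pytorch import Trainer
# trainer = Trainer(max_epochs=20, logger=tb_logger, enable_progress_bar=True,
#                   log_every_n_steps=1, **trainer_kwargs(torch.cuda.is_available()))
print(trainer_kwargs(False))
```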

15.0.3 ✅ Expected Outcome:

  • The model will be trained for the specified number of epochs (20).
  • Training progress and metrics will be logged and can be visualized with TensorBoard.
  • The model's performance in detecting deforestation patterns should improve over the epochs, which can be checked in the training and validation curves.

15.0.4 Note:

Model training was executed using a T4 GPU and completed in approximately 35 minutes. Runtime may vary depending on GPU type and system load, but a training time on the order of tens of minutes indicates the process is running as expected.

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
import re

train_dataset = Sentinel2ChangeDataset(
    before_dir=f"{base_dir}/train/before/",
    after_dir=f"{base_dir}/train/after/",
    label_dir=f"{base_dir}/train/label/",
)

val_dataset = Sentinel2ChangeDataset(
    before_dir=f"{base_dir}/val/before/",
    after_dir=f"{base_dir}/val/after/",
    label_dir=f"{base_dir}/val/label/",
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,  num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_dataset,   batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

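model = changeSegmentor()
# No-op for the Segmentor defined above (it has no gradient_checkpointing_enable
# method); kept for models that support gradient checkpointing
if hasattr(model.model, "gradient_checkpointing_enable"):
    model.model.gradient_checkpointing_enable()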

tb_logger = TensorBoardLogger("lightning_logs", name="clay_model")

trainer = Trainer(
    max_epochs=20,
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    logger=tb_logger,
    enable_progress_bar=True,
    log_every_n_steps=1,
)

# Optional: PyTorch 2.x compilation
# For long-running or large-scale training jobs, you may experiment with:
#   import torch
#   model.model = torch.compile(model.model, mode="reduce-overhead")
# or:
#   model.model = torch.compile(model.model, mode="max-autotune")
# Note: For short Colab runs, compilation overhead may increase total runtime.

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

16 Step 14: Visualizing Training Progress with TensorBoard

In this step, we will visualize the model training metrics using TensorBoard.
TensorBoard provides an interactive dashboard to monitor the loss and the other metrics logged by changeSegmentor (IoU, F1, precision, and recall) throughout the training process.

💡 Why Use TensorBoard?
- Real-time tracking of training and validation metrics.
- Visualization of loss curves and the logged IoU, F1, precision, and recall.
- Easy comparison between different training runs.
- Helps identify overfitting or underfitting during training.

🔧 How to Access TensorBoard:
- Run the cell below to launch TensorBoard.
- Open the generated link to view the training logs.
- The logs are stored in the lightning_logs/ directory.

✅ Expected Outcome:
A TensorBoard interface displaying training progress, including metrics like IoU, F1 Score, Precision, and Recall.

%load_ext tensorboard
%tensorboard --logdir lightning_logs/

TensorBoard provides a quick way to assess whether the model is learning meaningful patterns from the data. When reviewing the training curves, focus on overall trends rather than individual fluctuations. In this example, the training F1 score and IoU steadily increase over time, indicating that the model is improving its ability to correctly identify deforestation pixels relative to the background class. Precision also increases, suggesting the model is becoming more selective and reducing false positives. At the same time, recall begins very high and decreases slightly as training progresses, which is typical as the model transitions from over-predicting change (likely encouraged early on by the positive-class weight of 20 in the BCE loss) to making more balanced predictions. The validation metrics show similar improvement trends, with validation F1 and IoU stabilizing toward the later epochs, suggesting the model is generalizing reasonably well to unseen samples. The key takeaway is to look for consistent improvement in IoU and F1 alongside decreasing loss and relatively stable validation curves; these patterns suggest that the model is learning useful spatial-temporal features for deforestation detection rather than simply memorizing the training data.
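Noisy per-step curves are easier to read once smoothed. TensorBoard's smoothing slider is based on an exponential moving average; the pure-Python sketch below shows a plain EMA (TensorBoard's own implementation may differ in details such as debiasing of early values). `ema_smooth` and the IoU values are made up for illustration.

```python
def ema_smooth(values, weight=0.6):
    """Exponential moving average of a sequence of logged metric values.

    smoothed[t] = weight * smoothed[t - 1] + (1 - weight) * values[t],
    starting from smoothed[0] = values[0]. Higher weight means more smoothing.
    """
    smoothed, last = [], None
    for v in values:
        last = v if last is None else weight * last + (1 - weight) * v
        smoothed.append(last)
    return smoothed


noisy_iou = [0.10, 0.35, 0.20, 0.40, 0.30, 0.45, 0.38, 0.50]  # made-up values
print([round(x, 3) for x in ema_smooth(noisy_iou)])
```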

17 Step 15: Evaluation Metrics for Deforestation Detection

To assess model performance, predictions were compared with ground-truth deforestation masks at the pixel level. Each pixel was classified as a true positive (TP, correctly detected deforestation), false positive (FP, non-deforested area predicted as deforested), false negative (FN, missed deforestation), or true negative (correctly predicted non-deforested area); true negatives are not needed for the metrics below.

From these counts, the following metrics were computed:

  • Precision = TP / (TP + FP): proportion of predicted deforested pixels that are truly deforested (measures false alarms).
  • Recall = TP / (TP + FN): proportion of actual deforested pixels correctly detected (measures missed deforestation).
  • Intersection over Union (IoU) = TP / (TP + FP + FN): overlap between predicted and true deforested areas divided by their union.
  • F1 Score = 2 × Precision × Recall / (Precision + Recall): harmonic mean of precision and recall, providing a balanced measure under class imbalance.

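All metrics were calculated over the entire validation dataset by summing TP, FP, and FN across all validation pixels before computing the ratios, using model outputs binarized at a sigmoid probability threshold of 0.5.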
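Pooling matters under class imbalance. The pure-Python sketch below re-implements these metric definitions with pooled (micro-averaged) counts, as in the evaluation code, and contrasts pooled IoU with the mean of per-chip IoUs for two hypothetical chips; `pooled_metrics` and the counts are illustrative assumptions.

```python
def pooled_metrics(confusions, eps=1e-8):
    """Precision, recall, IoU and F1 from (tp, fp, fn) counts pooled over chips.

    confusions: iterable of (tp, fp, fn) tuples, e.g. one per validation chip.
    Counts are summed before dividing, as in the evaluation loop.
    """
    tp = sum(c[0] for c in confusions)
    fp = sum(c[1] for c in confusions)
    fn = sum(c[2] for c in confusions)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    iou = tp / (tp + fp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return {"precision": precision, "recall": recall, "iou": iou, "f1": f1}


# Two hypothetical chips: a large, mostly detected clearing and a small missed patch
chips = [(900, 100, 100), (0, 0, 20)]
pooled = pooled_metrics(chips)
per_chip_iou = [pooled_metrics([c])["iou"] for c in chips]
print(round(pooled["iou"], 3), round(sum(per_chip_iou) / len(per_chip_iou), 3))
```

Pooled IoU is about 0.80 here, while the per-chip mean is about 0.41, because the small, entirely missed patch counts as much as the large clearing when chips are averaged.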

import torch

# Select the device and move the trained Lightning model onto it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Put model in evaluation mode (disables dropout; BatchNorm uses running statistics)
model.eval()


# Counters for confusion values
total_tp = 0
total_fp = 0
total_fn = 0


# No gradients needed during evaluation
with torch.no_grad():
    # Loop through validation dataset
    for batch in val_loader:
        # Move data to device (CPU or GPU)
        before_datacube = {k: v.to(device) for k, v in batch["before"].items()}
        after_datacube  = {k: v.to(device) for k, v in batch["after"].items()}
        gt = batch["label"].to(device)  # Ground truth mask

        # Model prediction (logits)
        preds = model(before_datacube, after_datacube)

        # Convert logits to probabilities
        preds = torch.sigmoid(preds)

        # Convert probabilities to binary values (0 or 1)
        preds = (preds > 0.5).float()

        # Flatten tensors so we compare all pixels together
        preds = preds.view(-1)
        gt = gt.view(-1)

        # Count confusion components
        tp = torch.sum((preds == 1) & (gt == 1)).item()
        fp = torch.sum((preds == 1) & (gt == 0)).item()
        fn = torch.sum((preds == 0) & (gt == 1)).item()

        total_tp += tp
        total_fp += fp
        total_fn += fn


# -----------------------------
# Calculate Metrics
# -----------------------------
precision = total_tp / (total_tp + total_fp + 1e-8)
recall = total_tp / (total_tp + total_fn + 1e-8)
iou = total_tp / (total_tp + total_fp + total_fn + 1e-8)
f1_score = 2 * precision * recall / (precision + recall + 1e-8)


# -----------------------------
# Print Results
# -----------------------------
print("Validation Results")
print("------------------")
print(f"Precision : {precision:.4f}")
print(f"Recall    : {recall:.4f}")
print(f"IoU       : {iou:.4f}")
print(f"F1 Score  : {f1_score:.4f}")

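This step uses the trained model for inference: instead of learning from data, it generates predictions for ten validation samples and plots, for each one, the before and after images, the ground-truth mask, and the predicted mask (sigmoid probability thresholded at 0.5).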


import matplotlib.pyplot as plt

# Create figure with 10 rows x 4 columns
fig, axs = plt.subplots(10, 4, figsize=(16, 40))

# Helper to normalize image for display
def normalize(img):
    img = img - img.min()
    img = img / (img.max() + 1e-6)
    return img

# Loop over 10 batches from val_loader
val_iter = iter(val_loader)

for i in range(10):
    batch = next(val_iter)

    # Move data to the same device as the model
    before_datacube = {k: v.to(device) for k, v in batch["before"].items()}
    after_datacube = {k: v.to(device) for k, v in batch["after"].items()}

    try:
        with torch.no_grad():
            preds = model(before_datacube, after_datacube)
            preds = preds.sigmoid()
            pred_mask = (preds > 0.5).float()

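        # Get the first (only) sample in the batch; bands are ordered
        # B2 (blue), B3 (green), B4 (red), ..., so [2, 1, 0] gives true-colour RGB
        before_img = before_datacube["pixels"][0, [2, 1, 0]].cpu().numpy()
        after_img = after_datacube["pixels"][0, [2, 1, 0]].cpu().numpy()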
        pred_mask_img = pred_mask[0, 0].cpu().numpy()
        gt_mask_img = batch["label"][0, 0].cpu().numpy()

        axs[i, 0].imshow(normalize(before_img.transpose(1, 2, 0)))
        axs[i, 0].set_title(f"Sample {i+1} - Before (RGB)")
        axs[i, 1].imshow(normalize(after_img.transpose(1, 2, 0)))
        axs[i, 1].set_title("After (RGB)")
        axs[i, 2].imshow(gt_mask_img, cmap="gray")
        axs[i, 2].set_title("Ground Truth")
        axs[i, 3].imshow(pred_mask_img, cmap="gray")
        axs[i, 3].set_title("Prediction")

        for j in range(4):
            axs[i, j].axis("off")

    except RuntimeError as e:
        print(f"❌ RuntimeError on batch {i+1}: {e}")

plt.tight_layout()
plt.show()
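The plotting cell above rescales each image with a min-max stretch, which a few very bright pixels (for example clouds) can dominate. A common alternative is a per-channel percentile stretch. The NumPy sketch below assumes the chip bands follow the order of the wavelengths in changeSegmentor (B2 blue, B3 green, B4 red, ...), so indices (2, 1, 0) give true-colour RGB; `to_rgb` and the 2nd/98th percentile limits are illustrative choices.

```python
import numpy as np


def to_rgb(chip, rgb_bands=(2, 1, 0), low=2, high=98):
    """Turn a [C, H, W] chip into an [H, W, 3] image in [0, 1] for plt.imshow.

    rgb_bands assumes the band order B2 (blue), B3 (green), B4 (red), ...,
    so (2, 1, 0) selects red, green, blue. Each channel is stretched between
    its own low/high percentiles and clipped, which is less sensitive to
    outliers than a plain min-max stretch.
    """
    rgb = np.stack([chip[b] for b in rgb_bands], axis=-1).astype(np.float64)
    lo = np.percentile(rgb, low, axis=(0, 1))    # per-channel lower limit
    hi = np.percentile(rgb, high, axis=(0, 1))   # per-channel upper limit
    return np.clip((rgb - lo) / (hi - lo + 1e-6), 0.0, 1.0)


chip = np.random.default_rng(0).normal(size=(10, 64, 64))  # stand-in for a normalized chip
img = to_rgb(chip)
print(img.shape, img.min(), img.max())
```

In the plotting loop, `axs[i, 0].imshow(to_rgb(before_datacube["pixels"][0].cpu().numpy()))` could replace the call to `normalize`.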

18 Conclusion

Foundation models such as Clay represent a shift in how remote sensing analysis can be performed, as they are pre-trained on large volumes of satellite imagery and learn generalizable feature representations that can be adapted to new tasks such as deforestation detection. By leveraging a Vision Transformer backbone, the model captures broader spatial context and temporal differences between “before” and “after” observations, allowing it to highlight forest disturbance patterns without requiring task-specific feature engineering. In this example, the model successfully identifies the general location of deforestation when compared with the ground-truth mask, demonstrating the potential of foundation-model-based workflows for monitoring forest change. However, the prediction appears smoother than the ground truth and includes a vertical “blob”-like artifact. This likely reflects the limitations of patch-based and convolutional-style decoding: the encoder represents each 8 × 8 pixel patch as a single token, and the decoder must reconstruct pixel-level detail from this coarse grid, which can blur boundaries and generalize disturbances into broader shapes. As a result, while foundation models offer a powerful new paradigm for remote sensing by transferring learned representations across applications, careful tuning, validation, and interpretation remain important when applying them to operational deforestation monitoring tasks.

18.1 References

Aryal, R.R., Wespestad, C., Kennedy, R., Dilger, J., Dyson, K., Bullock, E., Khanal, N., Kono, M., Poortinga, A., Saah, D. and Tenneson, K. (2021). Lessons learned while implementing a time-series approach to forest canopy disturbance detection in Nepal. Remote Sensing, 13(14), 2666.

Chen, H., Qi, Z. and Shi, Z. (2022). Remote sensing image change detection with transformers. IEEE Transactions on Geoscience and Remote Sensing, 60, 1-14.

Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J. and Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.

Gorelick, N., Hancher, M., Dixon, M., Ilyushchenko, S., Thau, D. and Moore, R. (2017). Google Earth Engine: Planetary-scale geospatial analysis for everyone. Remote Sensing of Environment, 202, 18-27.

Kilbride, J.B., Poortinga, A., Bhandari, B., Thwal, N.S., Quyen, N.H., Silverman, J., Tenneson, K., Bell, D., Gregory, M., Kennedy, R. and Saah, D. (2023). Near real-time mapping of tropical forest disturbance using SAR and semantic segmentation in Google Earth Engine. Remote Sensing, 15(21), 5223.

Strudel, R., Garcia, R., Laptev, I. and Schmid, C. (2021). Segmenter: Transformer for semantic segmentation. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 7262-7272).

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł. and Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.