Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: CI

on:
push:
branches: [main]
pull_request:

jobs:
test:
name: Build & test (Python ${{ matrix.python-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.11"]

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install torch (CPU)
# Install the CPU-only torch build first to keep CI lightweight and avoid
# pulling large CUDA wheels via transitive dependencies.
run: |
python -m pip install --upgrade pip
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu

- name: Install package
run: pip install -e ".[dev,examples]"

- name: Import smoke test
run: python -c "import satclip; print('satclip public API:', satclip.__all__)"

- name: Run tests
run: pytest -q
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SatCLIP files
satclip_logs/

# Python
__pycache__/
*.egg-info/
build/
dist/
.pytest_cache/
26 changes: 20 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,27 @@

SatCLIP trains location and image encoders via contrastive learning, by matching images to their corresponding locations. This is analogous to the CLIP approach, which matches images to their corresponding text. Through this process, the location encoder learns characteristics of a location, as represented by satellite imagery. For more details, check out our [paper](https://arxiv.org/abs/2311.17179).

## Installation

SatCLIP is packaged as a Python module. Install it directly from GitHub:
```bash
pip install git+https://github.com/microsoft/satclip.git
```
Or, for a local/editable install (e.g. for training or development):
```bash
git clone https://github.com/microsoft/satclip.git
cd satclip
pip install -e .
```
Once installed, `satclip` can be imported from any directory — no need to set `PYTHONPATH`.

## Overview

Usage of SatCLIP is simple:

```python
from model import *
from location_encoder import *
import torch
from satclip.model import SatCLIP

model = SatCLIP(
embed_dim=512,
Expand Down Expand Up @@ -48,11 +62,11 @@ mkdir -p images
for f in data/shard-*.tar; do tar -xf "$f" -C images; done
```

Now, to train **SatCLIP** models, set the paths correctly (point `data.data_dir` in `satclip/configs/default.yaml` to this dataset directory), adapt training configs in `satclip/configs/default.yaml` and train SatCLIP by running:
Now, to train **SatCLIP** models, set the paths correctly (point `data.data_dir` in `satclip/configs/default.yaml` to this dataset directory), adapt training configs in `satclip/configs/default.yaml` and train SatCLIP by running the training module from the repository root:
```bash
cd satclip
python main.py
python -m satclip.main
```
This requires an editable install (`pip install -e .`, see [Installation](#installation)). You can point to a custom config with `python -m satclip.main --config path/to/config.yaml`.

### Use of the S2-100K dataset

Expand Down Expand Up @@ -96,7 +110,7 @@ We provide six pretrained SatCLIP models, trained with different vision encoders
Usage of pretrained models is simple. Simply specify the SatCLIP model you want to access, e.g. `satclip-vit16-l40`:
```python
from huggingface_hub import hf_hub_download
from load import get_satclip
from satclip.load import get_satclip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down
8 changes: 2 additions & 6 deletions notebooks/A01_Simple_SatCLIP_Usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@
}
],
"source": [
"!rm -r sample_data .config # Empty current directory\n",
"!git clone https://github.com/microsoft/satclip.git . # Clone SatCLIP repository"
"!pip install git+https://github.com/microsoft/satclip.git"
]
},
{
Expand Down Expand Up @@ -169,11 +168,8 @@
{
"cell_type": "code",
"source": [
"import sys\n",
"sys.path.append('./satclip')\n",
"\n",
"import torch\n",
"from load import get_satclip"
"from satclip.load import get_satclip"
],
"metadata": {
"id": "grEIwoFjoHvu"
Expand Down
10 changes: 3 additions & 7 deletions notebooks/A02_SatCLIP_Hugging_Face_Usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,7 @@
}
],
"source": [
"!rm -r sample_data .config # Empty current directory\n",
"!git clone https://github.com/microsoft/satclip.git . # Clone SatCLIP repository"
"!pip install git+https://github.com/microsoft/satclip.git"
]
},
{
Expand Down Expand Up @@ -473,11 +472,8 @@
{
"cell_type": "code",
"source": [
"import sys\n",
"sys.path.append('./satclip')\n",
"\n",
"import torch\n",
"from load import get_satclip"
"from satclip.load import get_satclip"
],
"metadata": {
"id": "grEIwoFjoHvu"
Expand All @@ -498,7 +494,7 @@
"cell_type": "code",
"source": [
"from huggingface_hub import hf_hub_download\n",
"from load import get_satclip\n",
"from satclip.load import get_satclip\n",
"import torch\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
Expand Down
8 changes: 2 additions & 6 deletions notebooks/B01_Example_Air_Temperature_Prediction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@
}
],
"source": [
"!rm -r sample_data .config # Empty current directory\n",
"!git clone https://github.com/microsoft/satclip.git . # Clone SatCLIP repository"
"!pip install git+https://github.com/microsoft/satclip.git"
]
},
{
Expand Down Expand Up @@ -168,11 +167,8 @@
{
"cell_type": "code",
"source": [
"import sys\n",
"sys.path.append('./satclip')\n",
"\n",
"import torch\n",
"from load import get_satclip\n",
"from satclip.load import get_satclip\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Automatically select device"
],
Expand Down
8 changes: 2 additions & 6 deletions notebooks/B02_Example_Image_Localization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,7 @@
}
],
"source": [
"!rm -r sample_data .config # Empty current directory\n",
"!git clone https://github.com/microsoft/satclip.git . # Clone SatCLIP repository"
"!pip install git+https://github.com/microsoft/satclip.git"
]
},
{
Expand Down Expand Up @@ -485,10 +484,7 @@
{
"cell_type": "code",
"source": [
"import sys\n",
"sys.path.append('./satclip')\n",
"\n",
"from load import get_satclip\n",
"from satclip.load import get_satclip\n",
"from huggingface_hub import hf_hub_download\n",
"import torch\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
Expand Down
70 changes: 70 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
[build-system]

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is the main change in this PR

requires = ["setuptools >= 77.0"]
build-backend = "setuptools.build_meta"

[project]
name = "satclip"
description = "A global, general-purpose geographic location encoder"
version = "0.0.1"
authors = [
{name="Konstantin Klemmer"},
{name="Esther Rolf"},
{name="Caleb Robinson"},
{name="Lester Mackey"},
{name="Marc Rußwurm"},
]
readme = "README.md"
license = "MIT"
license-files = ["LICENSE"]
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
]

# Direct (first-order) imports of the `satclip` package. Versions of the heavier
# scientific stack (numpy, sympy, torch, ...) are largely constrained by
# torchgeo/lightning, but are listed explicitly so the package does not silently
# rely on transitive dependencies.
dependencies = [
"albumentations",
"einops",
"huggingface_hub",
"lightning >=2.2,<3",
"matplotlib",
"numpy",
"pandas",
"rasterio >=1.3.10",
"sympy",
"timm",
"torch >=1.13",
"torchgeo >=0.5", # Forces Python 3.9+
"torchvision",
]

[project.optional-dependencies]
# Extra dependencies used by the example notebooks only.
examples = [
"scikit-learn",
]
# Development / CI tooling.
dev = [
"build",
"pytest",
"ruff",
]

[project.scripts]
satclip-train = "satclip.main:cli_main"

[project.urls]
Homepage = "https://github.com/microsoft/satclip"
Repository = "https://github.com/microsoft/satclip.git"
Issues = "https://github.com/microsoft/satclip/issues"

[tool.setuptools.packages.find]
include = ["satclip", "satclip.*"]

[tool.setuptools.package-data]
# Ship the default training config so `python -m satclip.main` works after install.
satclip = ["configs/*.yaml"]
5 changes: 0 additions & 5 deletions satclip/.gitignore

This file was deleted.

16 changes: 11 additions & 5 deletions satclip/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from . import *
from .main import *
from .model import *
from .loss import *
from .location_encoder import *
"""SatCLIP: a global, general-purpose geographic location encoder.

See https://github.com/microsoft/satclip for details.
"""

from .load import get_satclip
from .location_encoder import LocationEncoder
from .loss import SatCLIPLoss
from .model import SatCLIP

__all__ = ["SatCLIP", "LocationEncoder", "SatCLIPLoss", "get_satclip"]
4 changes: 2 additions & 2 deletions satclip/load.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from main import *
from .main import *

def get_satclip(ckpt_path, device, return_all=False):
ckpt = torch.load(ckpt_path,map_location=device)
Expand All @@ -15,4 +15,4 @@ def get_satclip(ckpt_path, device, return_all=False):
if return_all:
return geo_model
else:
return geo_model.location
return geo_model.location
10 changes: 5 additions & 5 deletions satclip/load_lightweight.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from location_encoder import get_neural_network, get_positional_encoding, LocationEncoder

from .location_encoder import get_neural_network, get_positional_encoding, LocationEncoder


def get_satclip_loc_encoder(ckpt_path, device):
Expand All @@ -14,7 +15,7 @@ def get_satclip_loc_encoder(ckpt_path, device):
hp['max_radius'],
hp['frequency_num']
)

nnet = get_neural_network(
hp['pe_type'],
posenc.embedding_dim,
Expand All @@ -25,12 +26,11 @@ def get_satclip_loc_encoder(ckpt_path, device):

# only load nnet params from state dict
state_dict = ckpt['state_dict']
state_dict = {k[k.index('nnet'):]:state_dict[k]
state_dict = {k[k.index('nnet'):]:state_dict[k]
for k in state_dict.keys() if 'nnet' in k}

loc_encoder = LocationEncoder(posenc, nnet).double()
loc_encoder.load_state_dict(state_dict)
loc_encoder.eval()

return loc_encoder

10 changes: 5 additions & 5 deletions satclip/location_encoder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from torch import nn, optim
import math

import torch
import torch.nn.functional as F
from einops import rearrange
import numpy as np
from datetime import datetime
import positional_encoding as PE
from torch import nn

from . import positional_encoding as PE

"""
FCNet
Expand Down Expand Up @@ -110,7 +110,7 @@ def forward(self, x, mods = None):
x *= rearrange(mod, 'd -> () d')

return self.last_layer(x)

class Sine(nn.Module):
def __init__(self, w0 = 1.):
super().__init__()
Expand Down
Loading