Skip to content

Commit 69207b5

Browse files
Jeomon authored and claude committed
Add image editing support (images param) to all image providers
BaseImage protocol updated: generate/agenerate now accept optional images: list[str] for image-to-image editing.

- OpenAI: gpt-image-1 accepts up to 16 reference images via images.edit(); dall-e-2 uses the first image as source plus an optional second image as mask; dall-e-3 raises ValueError (editing not supported).
- Google: uses edit_image() with RawReferenceImage objects (Vertex AI required for editing; generation still uses the standard API key).
- Together: encodes the first image as a base64 data URL passed via extra_body image_url, with a strength param for img2img-capable models.
- fal.ai: switches to the image_to_image_model endpoint when images are provided, encoding the first image as a base64 data URL with configurable strength.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent b582296 commit 69207b5

5 files changed

Lines changed: 348 additions & 167 deletions

File tree

operator_use/providers/base.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,23 +158,31 @@ def model(self) -> str:
158158
"""The name of the image generation model being used."""
159159
...
160160

161-
def generate(self, prompt: str, output_path: str, **kwargs) -> None:
162-
"""Generate an image from a text prompt and save it to a file.
161+
def generate(self, prompt: str, output_path: str, images: list[str] | None = None, **kwargs) -> None:
162+
"""Generate or edit an image and save it to a file.
163163
164164
Args:
165-
prompt: The text description to generate an image from.
165+
prompt: Text description of the image to generate or the edit to apply.
166166
output_path: Path where the generated image file will be saved.
167-
**kwargs: Provider-specific parameters (size, quality, style, etc.).
167+
images: Optional list of input image file paths. When provided, the
168+
provider edits or uses these as references rather than generating
169+
from scratch. Behaviour is provider-specific:
170+
- OpenAI gpt-image-1: up to 16 reference images
171+
- OpenAI dall-e-2: first image as source, second as mask (optional)
172+
- Google Imagen: first image as reference (Vertex AI required)
173+
- Together AI / fal.ai: first image used as img2img source
174+
**kwargs: Provider-specific parameters (size, quality, style, strength, etc.).
168175
"""
169176
...
170177

171-
async def agenerate(self, prompt: str, output_path: str, **kwargs) -> None:
172-
"""Asynchronously generate an image from a text prompt and save it to a file.
178+
async def agenerate(self, prompt: str, output_path: str, images: list[str] | None = None, **kwargs) -> None:
179+
"""Asynchronously generate or edit an image and save it to a file.
173180
174181
Args:
175-
prompt: The text description to generate an image from.
182+
prompt: Text description of the image to generate or the edit to apply.
176183
output_path: Path where the generated image file will be saved.
177-
**kwargs: Provider-specific parameters (size, quality, style, etc.).
184+
images: Optional list of input image file paths. See generate() for details.
185+
**kwargs: Provider-specific parameters (size, quality, style, strength, etc.).
178186
"""
179187
...
180188

operator_use/providers/fal/image.py

Lines changed: 78 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import base64
12
import logging
3+
import mimetypes
24
import os
35
import urllib.request
46
from typing import Optional
@@ -8,22 +10,45 @@
810
logger = logging.getLogger(__name__)
911

1012

13+
def _encode_image_b64(path: str) -> str:
14+
"""Encode a local image file as a base64 data URL."""
15+
mime, _ = mimetypes.guess_type(path)
16+
mime = mime or "image/png"
17+
with open(path, "rb") as f:
18+
data = base64.b64encode(f.read()).decode()
19+
return f"data:{mime};base64,{data}"
20+
21+
1122
class ImageFal(BaseImage):
12-
"""fal.ai image generation provider.
23+
"""fal.ai image generation and editing provider.
24+
25+
Uses the fal-client SDK to run FLUX and other models on fal.ai.
26+
Requires the ``fal-client`` package: ``pip install fal-client``
1327
14-
Uses the fal-client SDK to run FLUX and other models on fal.ai infrastructure.
15-
Requires the `fal-client` package: pip install fal-client
28+
Generation (no images):
29+
Runs the configured model with a text prompt.
30+
31+
Editing (images provided):
32+
Switches to the ``image_to_image_model`` endpoint and passes the
33+
first image as ``image_url``. ``strength`` controls how much the
34+
output deviates from the input (0.0 = unchanged, 1.0 = fully
35+
regenerated).
1636
1737
Args:
18-
model: The fal model ID to use (default: "fal-ai/flux/schnell").
38+
model: The fal model ID for generation (default: "fal-ai/flux/schnell").
1939
Popular options:
20-
"fal-ai/flux/schnell" (fastest, 4 steps)
21-
"fal-ai/flux/dev" (higher quality)
22-
"fal-ai/flux-pro" (best quality, paid)
23-
"fal-ai/flux-pro/v1.1" (latest pro)
24-
"fal-ai/flux-lora" (LoRA support)
40+
"fal-ai/flux/schnell" (fastest, 4 steps)
41+
"fal-ai/flux/dev" (higher quality)
42+
"fal-ai/flux-pro" (best quality, paid)
43+
"fal-ai/flux-pro/v1.1"
2544
"fal-ai/stable-diffusion-v3-medium"
26-
image_size: Output image size preset (default: "landscape_4_3").
45+
image_to_image_model: Model used when input images are provided
46+
(default: "fal-ai/flux/dev/image-to-image").
47+
Popular options:
48+
"fal-ai/flux/dev/image-to-image"
49+
"fal-ai/flux-pro/v1/redux"
50+
"fal-ai/flux-lora/image-to-image"
51+
image_size: Output size preset for generation (default: "landscape_4_3").
2752
Options: "square_hd", "square", "portrait_4_3", "portrait_16_9",
2853
"landscape_4_3", "landscape_16_9".
2954
num_inference_steps: Steps for generation (default: 4 for schnell).
@@ -33,19 +58,26 @@ class ImageFal(BaseImage):
3358
```python
3459
from operator_use.providers.fal import ImageFal
3560
36-
provider = ImageFal(model="fal-ai/flux/schnell")
61+
provider = ImageFal()
62+
63+
# Generate from scratch
3764
provider.generate("a red panda coding on a laptop", "output.png")
65+
66+
# Edit with a reference image
67+
provider.generate("make it sunset", "output.png", images=["input.png"], strength=0.85)
3868
```
3969
"""
4070

4171
def __init__(
4272
self,
4373
model: str = "fal-ai/flux/schnell",
74+
image_to_image_model: str = "fal-ai/flux/dev/image-to-image",
4475
image_size: str = "landscape_4_3",
4576
num_inference_steps: int = 4,
4677
api_key: Optional[str] = None,
4778
):
4879
self._model = model
80+
self.image_to_image_model = image_to_image_model
4981
self.image_size = image_size
5082
self.num_inference_steps = num_inference_steps
5183
self.api_key = api_key or os.environ.get("FAL_KEY")
@@ -56,51 +88,56 @@ def __init__(
5688
def model(self) -> str:
5789
return self._model
5890

59-
def _build_arguments(self, prompt: str, **kwargs) -> dict:
60-
return {
61-
"prompt": prompt,
62-
"image_size": kwargs.get("image_size", self.image_size),
63-
"num_inference_steps": kwargs.get("num_inference_steps", self.num_inference_steps),
64-
"num_images": 1,
65-
"enable_safety_checker": True,
66-
}
67-
68-
def generate(self, prompt: str, output_path: str, **kwargs) -> None:
69-
"""Generate an image and save it to output_path.
70-
71-
Args:
72-
prompt: Text description of the image to generate.
73-
output_path: Path where the image will be saved.
74-
**kwargs: Override image_size or num_inference_steps for this call.
75-
"""
91+
def _build_arguments(self, prompt: str, images: list[str] | None, **kwargs) -> tuple[str, dict]:
92+
"""Return (endpoint, arguments) depending on whether images are provided."""
93+
if images:
94+
endpoint = kwargs.get("image_to_image_model", self.image_to_image_model)
95+
args = {
96+
"prompt": prompt,
97+
"image_url": _encode_image_b64(images[0]),
98+
"strength": kwargs.get("strength", 0.85),
99+
"num_inference_steps": kwargs.get("num_inference_steps", 28),
100+
"num_images": 1,
101+
"enable_safety_checker": True,
102+
}
103+
if kwargs.get("image_size"):
104+
args["image_size"] = kwargs["image_size"]
105+
else:
106+
endpoint = self._model
107+
args = {
108+
"prompt": prompt,
109+
"image_size": kwargs.get("image_size", self.image_size),
110+
"num_inference_steps": kwargs.get("num_inference_steps", self.num_inference_steps),
111+
"num_images": 1,
112+
"enable_safety_checker": True,
113+
}
114+
return endpoint, args
115+
116+
def _save_from_url(self, url: str, output_path: str) -> None:
117+
urllib.request.urlretrieve(url, output_path)
118+
119+
def generate(self, prompt: str, output_path: str, images: list[str] | None = None, **kwargs) -> None:
76120
try:
77121
import fal_client
78122
except ImportError:
79123
raise ImportError("fal-client is required: pip install fal-client")
80124

81-
result = fal_client.run(self._model, arguments=self._build_arguments(prompt, **kwargs))
125+
endpoint, args = self._build_arguments(prompt, images, **kwargs)
126+
result = fal_client.run(endpoint, arguments=args)
82127
url = result["images"][0]["url"]
83-
urllib.request.urlretrieve(url, output_path)
128+
self._save_from_url(url, output_path)
84129
logger.debug(f"[ImageFal] Image saved to {output_path}")
85130

86-
async def agenerate(self, prompt: str, output_path: str, **kwargs) -> None:
87-
"""Asynchronously generate an image and save it to output_path.
88-
89-
Args:
90-
prompt: Text description of the image to generate.
91-
output_path: Path where the image will be saved.
92-
**kwargs: Override image_size or num_inference_steps for this call.
93-
"""
131+
async def agenerate(self, prompt: str, output_path: str, images: list[str] | None = None, **kwargs) -> None:
94132
try:
95133
import fal_client
96134
except ImportError:
97135
raise ImportError("fal-client is required: pip install fal-client")
98136

99137
import aiohttp as _aiohttp
100138

101-
result = await fal_client.run_async(
102-
self._model, arguments=self._build_arguments(prompt, **kwargs)
103-
)
139+
endpoint, args = self._build_arguments(prompt, images, **kwargs)
140+
result = await fal_client.run_async(endpoint, arguments=args)
104141
url = result["images"][0]["url"]
105142
async with _aiohttp.ClientSession() as session:
106143
async with session.get(url) as resp:

operator_use/providers/google/image.py

Lines changed: 79 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,59 @@
99

1010

1111
class ImageGoogle(BaseImage):
12-
"""Google Imagen image generation provider.
12+
"""Google Imagen image generation and editing provider.
1313
14-
Uses the Google GenAI SDK (Imagen 3) to generate images from text prompts.
14+
Uses the Google GenAI SDK for text-to-image generation (Imagen 3) and
15+
image editing (requires Vertex AI credentials).
16+
17+
Generation (no images):
18+
Uses ``imagen-3.0-generate-002`` via the standard GenAI API key.
19+
20+
Editing (images provided):
21+
Uses ``models.edit_image()`` with Vertex AI — requires
22+
``GOOGLE_CLOUD_PROJECT`` and ``GOOGLE_CLOUD_LOCATION`` environment
23+
variables in addition to the API key, and model
24+
``imagen-3.0-capability-001``.
1525
1626
Args:
17-
model: The Imagen model to use (default: "imagen-3.0-generate-002").
27+
model: Generation model (default: "imagen-3.0-generate-002").
28+
edit_model: Editing model (default: "imagen-3.0-capability-001").
1829
api_key: Google API key. Falls back to GEMINI_API_KEY env variable.
19-
negative_prompt: Optional description of what to exclude from the image.
30+
negative_prompt: Optional description of what to exclude.
31+
project: Google Cloud project ID for Vertex AI editing.
32+
Falls back to GOOGLE_CLOUD_PROJECT env variable.
33+
location: Google Cloud location for Vertex AI editing.
34+
Falls back to GOOGLE_CLOUD_LOCATION env variable (default: "us-central1").
2035
2136
Example:
2237
```python
2338
from operator_use.providers.google import ImageGoogle
2439
40+
# Generation (standard API key)
2541
provider = ImageGoogle()
2642
provider.generate("a red panda coding on a laptop", "output.png")
43+
44+
# Editing (Vertex AI)
45+
provider = ImageGoogle(project="my-project")
46+
provider.generate("make it sunset", "output.png", images=["input.png"])
2747
```
2848
"""
2949

3050
def __init__(
3151
self,
3252
model: str = "imagen-3.0-generate-002",
53+
edit_model: str = "imagen-3.0-capability-001",
3354
api_key: Optional[str] = None,
3455
negative_prompt: Optional[str] = None,
56+
project: Optional[str] = None,
57+
location: Optional[str] = None,
3558
):
3659
self._model = model
60+
self.edit_model = edit_model
3761
self.negative_prompt = negative_prompt
3862
self.api_key = api_key or os.environ.get("GEMINI_API_KEY")
63+
self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT")
64+
self.location = location or os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1")
3965

4066
@property
4167
def model(self) -> str:
@@ -45,38 +71,57 @@ def _make_client(self):
4571
from google import genai
4672
return genai.Client(api_key=self.api_key)
4773

48-
def generate(self, prompt: str, output_path: str, **kwargs) -> None:
49-
"""Generate an image and save it to output_path.
50-
51-
Args:
52-
prompt: Text description of the image to generate.
53-
output_path: Path where the PNG image will be saved.
54-
**kwargs: Override negative_prompt for this call.
55-
"""
74+
def _make_vertex_client(self):
5675
from google import genai
76+
if not self.project:
77+
raise ValueError(
78+
"Google image editing requires a Vertex AI project. "
79+
"Set GOOGLE_CLOUD_PROJECT env variable or pass project= to ImageGoogle()."
80+
)
81+
return genai.Client(vertexai=True, project=self.project, location=self.location)
82+
83+
def generate(self, prompt: str, output_path: str, images: list[str] | None = None, **kwargs) -> None:
84+
from google import genai
85+
86+
if images:
87+
client = self._make_vertex_client()
88+
reference_images = [
89+
genai.types.RawReferenceImage(
90+
reference_id=i + 1,
91+
reference_image=genai.types.Image.from_file(path),
92+
)
93+
for i, path in enumerate(images)
94+
]
95+
config = genai.types.EditImageConfig(
96+
edit_mode=kwargs.get("edit_mode", "EDIT_MODE_DEFAULT"),
97+
number_of_images=1,
98+
output_mime_type="image/png",
99+
negative_prompt=kwargs.get("negative_prompt", self.negative_prompt),
100+
)
101+
response = client.models.edit_image(
102+
model=self.edit_model,
103+
prompt=prompt,
104+
reference_images=reference_images,
105+
config=config,
106+
)
107+
image_bytes = response.generated_images[0].image.image_bytes
108+
else:
109+
client = self._make_client()
110+
config = genai.types.GenerateImagesConfig(
111+
number_of_images=1,
112+
output_mime_type="image/png",
113+
negative_prompt=kwargs.get("negative_prompt", self.negative_prompt),
114+
)
115+
response = client.models.generate_images(
116+
model=self._model,
117+
prompt=prompt,
118+
config=config,
119+
)
120+
image_bytes = response.generated_images[0].image.image_bytes
57121

58-
client = self._make_client()
59-
config = genai.types.GenerateImagesConfig(
60-
number_of_images=1,
61-
output_mime_type="image/png",
62-
negative_prompt=kwargs.get("negative_prompt", self.negative_prompt),
63-
)
64-
response = client.models.generate_images(
65-
model=self._model,
66-
prompt=prompt,
67-
config=config,
68-
)
69-
image_data = response.generated_images[0].image.image_data
70122
with open(output_path, "wb") as f:
71-
f.write(image_data)
123+
f.write(image_bytes)
72124
logger.debug(f"[ImageGoogle] Image saved to {output_path}")
73125

74-
async def agenerate(self, prompt: str, output_path: str, **kwargs) -> None:
75-
"""Asynchronously generate an image and save it to output_path.
76-
77-
Args:
78-
prompt: Text description of the image to generate.
79-
output_path: Path where the PNG image will be saved.
80-
**kwargs: Override negative_prompt for this call.
81-
"""
82-
await asyncio.to_thread(self.generate, prompt, output_path, **kwargs)
126+
async def agenerate(self, prompt: str, output_path: str, images: list[str] | None = None, **kwargs) -> None:
127+
await asyncio.to_thread(self.generate, prompt, output_path, images, **kwargs)

0 commit comments

Comments (0)