1+ import base64
12import logging
3+ import mimetypes
24import os
35import urllib .request
46from typing import Optional
810logger = logging .getLogger (__name__ )
911
1012
13+ def _encode_image_b64 (path : str ) -> str :
14+ """Encode a local image file as a base64 data URL."""
15+ mime , _ = mimetypes .guess_type (path )
16+ mime = mime or "image/png"
17+ with open (path , "rb" ) as f :
18+ data = base64 .b64encode (f .read ()).decode ()
19+ return f"data:{ mime } ;base64,{ data } "
20+
21+
1122class ImageFal (BaseImage ):
12- """fal.ai image generation provider.
23+ """fal.ai image generation and editing provider.
24+
25+ Uses the fal-client SDK to run FLUX and other models on fal.ai.
26+ Requires the ``fal-client`` package: ``pip install fal-client``
1327
14- Uses the fal-client SDK to run FLUX and other models on fal.ai infrastructure.
15- Requires the `fal-client` package: pip install fal-client
28+ Generation (no images):
29+ Runs the configured model with a text prompt.
30+
31+ Editing (images provided):
32+ Switches to the ``image_to_image_model`` endpoint and passes the
33+ first image as ``image_url``. ``strength`` controls how much the
34+ output deviates from the input (0.0 = unchanged, 1.0 = fully
35+ regenerated).
1636
1737 Args:
18- model: The fal model ID to use (default: "fal-ai/flux/schnell").
38+ model: The fal model ID for generation (default: "fal-ai/flux/schnell").
1939 Popular options:
20- "fal-ai/flux/schnell" (fastest, 4 steps)
21- "fal-ai/flux/dev" (higher quality)
22- "fal-ai/flux-pro" (best quality, paid)
23- "fal-ai/flux-pro/v1.1" (latest pro)
24- "fal-ai/flux-lora" (LoRA support)
40+ "fal-ai/flux/schnell" (fastest, 4 steps)
41+ "fal-ai/flux/dev" (higher quality)
42+ "fal-ai/flux-pro" (best quality, paid)
43+ "fal-ai/flux-pro/v1.1"
2544 "fal-ai/stable-diffusion-v3-medium"
26- image_size: Output image size preset (default: "landscape_4_3").
45+ image_to_image_model: Model used when input images are provided
46+ (default: "fal-ai/flux/dev/image-to-image").
47+ Popular options:
48+ "fal-ai/flux/dev/image-to-image"
49+ "fal-ai/flux-pro/v1/redux"
50+ "fal-ai/flux-lora/image-to-image"
51+ image_size: Output size preset for generation (default: "landscape_4_3").
2752 Options: "square_hd", "square", "portrait_4_3", "portrait_16_9",
2853 "landscape_4_3", "landscape_16_9".
2954 num_inference_steps: Steps for generation (default: 4 for schnell).
@@ -33,19 +58,26 @@ class ImageFal(BaseImage):
3358 ```python
3459 from operator_use.providers.fal import ImageFal
3560
36- provider = ImageFal(model="fal-ai/flux/schnell")
61+ provider = ImageFal()
62+
63+ # Generate from scratch
3764 provider.generate("a red panda coding on a laptop", "output.png")
65+
66+ # Edit with a reference image
67+ provider.generate("make it sunset", "output.png", images=["input.png"], strength=0.85)
3868 ```
3969 """
4070
4171 def __init__ (
4272 self ,
4373 model : str = "fal-ai/flux/schnell" ,
74+ image_to_image_model : str = "fal-ai/flux/dev/image-to-image" ,
4475 image_size : str = "landscape_4_3" ,
4576 num_inference_steps : int = 4 ,
4677 api_key : Optional [str ] = None ,
4778 ):
4879 self ._model = model
80+ self .image_to_image_model = image_to_image_model
4981 self .image_size = image_size
5082 self .num_inference_steps = num_inference_steps
5183 self .api_key = api_key or os .environ .get ("FAL_KEY" )
@@ -56,51 +88,56 @@ def __init__(
5688 def model (self ) -> str :
5789 return self ._model
5890
59- def _build_arguments (self , prompt : str , ** kwargs ) -> dict :
60- return {
61- "prompt" : prompt ,
62- "image_size" : kwargs .get ("image_size" , self .image_size ),
63- "num_inference_steps" : kwargs .get ("num_inference_steps" , self .num_inference_steps ),
64- "num_images" : 1 ,
65- "enable_safety_checker" : True ,
66- }
67-
68- def generate (self , prompt : str , output_path : str , ** kwargs ) -> None :
69- """Generate an image and save it to output_path.
70-
71- Args:
72- prompt: Text description of the image to generate.
73- output_path: Path where the image will be saved.
74- **kwargs: Override image_size or num_inference_steps for this call.
75- """
91+ def _build_arguments (self , prompt : str , images : list [str ] | None , ** kwargs ) -> tuple [str , dict ]:
92+ """Return (endpoint, arguments) depending on whether images are provided."""
93+ if images :
94+ endpoint = kwargs .get ("image_to_image_model" , self .image_to_image_model )
95+ args = {
96+ "prompt" : prompt ,
97+ "image_url" : _encode_image_b64 (images [0 ]),
98+ "strength" : kwargs .get ("strength" , 0.85 ),
99+ "num_inference_steps" : kwargs .get ("num_inference_steps" , 28 ),
100+ "num_images" : 1 ,
101+ "enable_safety_checker" : True ,
102+ }
103+ if kwargs .get ("image_size" ):
104+ args ["image_size" ] = kwargs ["image_size" ]
105+ else :
106+ endpoint = self ._model
107+ args = {
108+ "prompt" : prompt ,
109+ "image_size" : kwargs .get ("image_size" , self .image_size ),
110+ "num_inference_steps" : kwargs .get ("num_inference_steps" , self .num_inference_steps ),
111+ "num_images" : 1 ,
112+ "enable_safety_checker" : True ,
113+ }
114+ return endpoint , args
115+
116+ def _save_from_url (self , url : str , output_path : str ) -> None :
117+ urllib .request .urlretrieve (url , output_path )
118+
119+ def generate (self , prompt : str , output_path : str , images : list [str ] | None = None , ** kwargs ) -> None :
76120 try :
77121 import fal_client
78122 except ImportError :
79123 raise ImportError ("fal-client is required: pip install fal-client" )
80124
81- result = fal_client .run (self ._model , arguments = self ._build_arguments (prompt , ** kwargs ))
125+ endpoint , args = self ._build_arguments (prompt , images , ** kwargs )
126+ result = fal_client .run (endpoint , arguments = args )
82127 url = result ["images" ][0 ]["url" ]
83- urllib . request . urlretrieve (url , output_path )
128+ self . _save_from_url (url , output_path )
84129 logger .debug (f"[ImageFal] Image saved to { output_path } " )
85130
86- async def agenerate (self , prompt : str , output_path : str , ** kwargs ) -> None :
87- """Asynchronously generate an image and save it to output_path.
88-
89- Args:
90- prompt: Text description of the image to generate.
91- output_path: Path where the image will be saved.
92- **kwargs: Override image_size or num_inference_steps for this call.
93- """
131+ async def agenerate (self , prompt : str , output_path : str , images : list [str ] | None = None , ** kwargs ) -> None :
94132 try :
95133 import fal_client
96134 except ImportError :
97135 raise ImportError ("fal-client is required: pip install fal-client" )
98136
99137 import aiohttp as _aiohttp
100138
101- result = await fal_client .run_async (
102- self ._model , arguments = self ._build_arguments (prompt , ** kwargs )
103- )
139+ endpoint , args = self ._build_arguments (prompt , images , ** kwargs )
140+ result = await fal_client .run_async (endpoint , arguments = args )
104141 url = result ["images" ][0 ]["url" ]
105142 async with _aiohttp .ClientSession () as session :
106143 async with session .get (url ) as resp :
0 commit comments