123 lines
4.1 KiB
Python
123 lines
4.1 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from PIL import Image as PILImage
|
|
|
|
try:
|
|
from google import genai
|
|
from google.genai import types as genai_types
|
|
except ImportError: # pragma: no cover - exercised in tests via runtime guard
|
|
genai = None
|
|
genai_types = None
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class GeminiImageService:
|
|
def __init__(
|
|
self,
|
|
*,
|
|
api_key: str,
|
|
model: str = "gemini-3-pro-image-preview",
|
|
client: Any | None = None,
|
|
) -> None:
|
|
self.api_key = api_key
|
|
self.model = model
|
|
self._client = client or self._build_client()
|
|
|
|
def _build_client(self) -> Any:
|
|
if genai is None:
|
|
raise RuntimeError("google-genai is not installed. Install it to use GeminiImageService.")
|
|
return genai.Client(api_key=self.api_key)
|
|
|
|
def _build_config(self, *, aspect_ratio: str, image_size: str) -> Any:
|
|
if genai_types is None:
|
|
return {
|
|
"response_modalities": ["TEXT", "IMAGE"],
|
|
"image_config": {
|
|
"aspect_ratio": aspect_ratio,
|
|
"image_size": image_size,
|
|
},
|
|
}
|
|
|
|
image_config_factory = getattr(genai_types, "ImageConfig", None)
|
|
generate_config_factory = getattr(genai_types, "GenerateContentConfig", None)
|
|
image_config = (
|
|
image_config_factory(aspect_ratio=aspect_ratio, image_size=image_size)
|
|
if callable(image_config_factory)
|
|
else {
|
|
"aspect_ratio": aspect_ratio,
|
|
"image_size": image_size,
|
|
}
|
|
)
|
|
|
|
if callable(generate_config_factory):
|
|
return generate_config_factory(
|
|
response_modalities=["TEXT", "IMAGE"],
|
|
image_config=image_config,
|
|
)
|
|
|
|
return {
|
|
"response_modalities": ["TEXT", "IMAGE"],
|
|
"image_config": image_config,
|
|
}
|
|
|
|
@staticmethod
|
|
def _extract_parts(response: Any) -> list[Any]:
|
|
direct_parts = getattr(response, "parts", None)
|
|
if direct_parts:
|
|
return list(direct_parts)
|
|
|
|
candidates = getattr(response, "candidates", None) or []
|
|
for candidate in candidates:
|
|
candidate_parts = getattr(getattr(candidate, "content", None), "parts", None)
|
|
if candidate_parts:
|
|
return list(candidate_parts)
|
|
return []
|
|
|
|
def generate_image(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
output_path: str,
|
|
reference_image_path: str | None = None,
|
|
aspect_ratio: str = "16:9",
|
|
image_size: str = "2K",
|
|
) -> dict[str, str]:
|
|
reference_image: PILImage.Image | None = None
|
|
|
|
try:
|
|
contents: list[Any] = [prompt]
|
|
if reference_image_path:
|
|
reference_image = PILImage.open(reference_image_path)
|
|
contents.append(reference_image)
|
|
|
|
response = self._client.models.generate_content(
|
|
model=self.model,
|
|
contents=contents,
|
|
config=self._build_config(aspect_ratio=aspect_ratio, image_size=image_size),
|
|
)
|
|
|
|
for part in self._extract_parts(response):
|
|
if getattr(part, "inline_data", None) is not None and hasattr(part, "as_image"):
|
|
output = Path(output_path)
|
|
output.parent.mkdir(parents=True, exist_ok=True)
|
|
part.as_image().save(str(output))
|
|
return {"status": "success", "path": str(output)}
|
|
|
|
text_parts = [str(part.text).strip() for part in self._extract_parts(response) if getattr(part, "text", None)]
|
|
message = "No image in API response."
|
|
if text_parts:
|
|
message = f"{message} {' '.join(text_parts)}"
|
|
return {"status": "error", "error": message}
|
|
except Exception as exc:
|
|
logger.exception("Gemini image generation failed.")
|
|
return {"status": "error", "error": str(exc)}
|
|
finally:
|
|
if reference_image is not None:
|
|
reference_image.close()
|