Files
2026-04-04 19:29:27 +09:00

123 lines
4.1 KiB
Python

from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from PIL import Image as PILImage
try:
from google import genai
from google.genai import types as genai_types
except ImportError: # pragma: no cover - exercised in tests via runtime guard
genai = None
genai_types = None
logger = logging.getLogger(__name__)
class GeminiImageService:
def __init__(
self,
*,
api_key: str,
model: str = "gemini-3-pro-image-preview",
client: Any | None = None,
) -> None:
self.api_key = api_key
self.model = model
self._client = client or self._build_client()
def _build_client(self) -> Any:
if genai is None:
raise RuntimeError("google-genai is not installed. Install it to use GeminiImageService.")
return genai.Client(api_key=self.api_key)
def _build_config(self, *, aspect_ratio: str, image_size: str) -> Any:
if genai_types is None:
return {
"response_modalities": ["TEXT", "IMAGE"],
"image_config": {
"aspect_ratio": aspect_ratio,
"image_size": image_size,
},
}
image_config_factory = getattr(genai_types, "ImageConfig", None)
generate_config_factory = getattr(genai_types, "GenerateContentConfig", None)
image_config = (
image_config_factory(aspect_ratio=aspect_ratio, image_size=image_size)
if callable(image_config_factory)
else {
"aspect_ratio": aspect_ratio,
"image_size": image_size,
}
)
if callable(generate_config_factory):
return generate_config_factory(
response_modalities=["TEXT", "IMAGE"],
image_config=image_config,
)
return {
"response_modalities": ["TEXT", "IMAGE"],
"image_config": image_config,
}
@staticmethod
def _extract_parts(response: Any) -> list[Any]:
direct_parts = getattr(response, "parts", None)
if direct_parts:
return list(direct_parts)
candidates = getattr(response, "candidates", None) or []
for candidate in candidates:
candidate_parts = getattr(getattr(candidate, "content", None), "parts", None)
if candidate_parts:
return list(candidate_parts)
return []
def generate_image(
self,
*,
prompt: str,
output_path: str,
reference_image_path: str | None = None,
aspect_ratio: str = "16:9",
image_size: str = "2K",
) -> dict[str, str]:
reference_image: PILImage.Image | None = None
try:
contents: list[Any] = [prompt]
if reference_image_path:
reference_image = PILImage.open(reference_image_path)
contents.append(reference_image)
response = self._client.models.generate_content(
model=self.model,
contents=contents,
config=self._build_config(aspect_ratio=aspect_ratio, image_size=image_size),
)
for part in self._extract_parts(response):
if getattr(part, "inline_data", None) is not None and hasattr(part, "as_image"):
output = Path(output_path)
output.parent.mkdir(parents=True, exist_ok=True)
part.as_image().save(str(output))
return {"status": "success", "path": str(output)}
text_parts = [str(part.text).strip() for part in self._extract_parts(response) if getattr(part, "text", None)]
message = "No image in API response."
if text_parts:
message = f"{message} {' '.join(text_parts)}"
return {"status": "error", "error": message}
except Exception as exc:
logger.exception("Gemini image generation failed.")
return {"status": "error", "error": str(exc)}
finally:
if reference_image is not None:
reference_image.close()