| | import asyncio |
| | import os |
| | import threading |
| | import time |
| | import traceback |
| | from pathlib import Path |
| | from typing import Optional, Dict, List |
| |
|
| | import cv2 |
| | import numpy as np |
| | import socketio |
| | import torch |
| |
|
| | try: |
| | torch._C._jit_override_can_fuse_on_cpu(False) |
| | torch._C._jit_override_can_fuse_on_gpu(False) |
| | torch._C._jit_set_texpr_fuser_enabled(False) |
| | torch._C._jit_set_nvfuser_enabled(False) |
| | except: |
| | pass |
| |
|
| |
|
| | import uvicorn |
| | from PIL import Image |
| | from fastapi import APIRouter, FastAPI, Request, UploadFile |
| | from fastapi.encoders import jsonable_encoder |
| | from fastapi.exceptions import HTTPException |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi.responses import JSONResponse, FileResponse, Response |
| | from fastapi.staticfiles import StaticFiles |
| | from loguru import logger |
| | from socketio import AsyncServer |
| |
|
| | from iopaint.file_manager import FileManager |
| | from iopaint.helper import ( |
| | load_img, |
| | decode_base64_to_image, |
| | pil_to_bytes, |
| | numpy_to_bytes, |
| | concat_alpha_channel, |
| | gen_frontend_mask, |
| | adjust_mask, |
| | ) |
| | from iopaint.model.utils import torch_gc |
| | from iopaint.model_manager import ModelManager |
| | from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg |
| | from iopaint.plugins.base_plugin import BasePlugin |
| | from iopaint.plugins.remove_bg import RemoveBG |
| | from iopaint.schema import ( |
| | GenInfoResponse, |
| | ApiConfig, |
| | ServerConfigResponse, |
| | SwitchModelRequest, |
| | InpaintRequest, |
| | RunPluginRequest, |
| | SDSampler, |
| | PluginInfo, |
| | AdjustMaskRequest, |
| | RemoveBGModel, |
| | SwitchPluginModelRequest, |
| | ModelInfo, |
| | InteractiveSegModel, |
| | RealESRGANModel, |
| | ) |
| |
|
# Absolute, symlink-resolved directory of the folder that contains this module.
CURRENT_DIR: Path = Path(__file__).parent.absolute().resolve()
# Bundled front-end build; Api mounts it as static files at "/".
WEB_APP_DIR: Path = CURRENT_DIR.joinpath("web_app")
| |
|
| |
|
def api_middleware(app: FastAPI):
    """Install JSON error handling and permissive CORS on ``app``.

    Any exception escaping a request is turned into a ``JSONResponse`` whose body
    is ``{"error", "detail", "body", "errors"}`` and whose status code is the
    exception's ``status_code`` attribute (500 if absent). For exceptions that are
    not ``HTTPException`` the request method/URL and error dict are printed, followed
    by a traceback: a rich traceback when the ``WEBUI_RICH_EXCEPTIONS`` environment
    variable is set and ``rich`` is importable, a plain one otherwise.

    Args:
        app: The FastAPI application to configure (mutated in place).
    """
    rich_available = False
    try:
        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
            # anyio/starlette are imported only so rich can hide their frames.
            import anyio
            import starlette
            from rich.console import Console

            console = Console()
            rich_available = True
    except Exception:
        # Optional dependency missing: fall back to plain tracebacks.
        pass

    def handle_exception(request: Request, e: Exception):
        """Build the JSON error response for ``e``, reporting unexpected errors."""
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get("detail", ""),
            "body": vars(e).get("body", ""),
            "errors": str(e),
        }
        # HTTPExceptions are deliberate responses; don't dump a traceback for them.
        if not isinstance(e, HTTPException):
            message = f"API error: {request.method}: {request.url} {err}"
            # Report the request context in both modes (previously the plain
            # path printed only the traceback and lost method/URL).
            print(message)
            if rich_available:
                console.print_exception(
                    show_locals=True,
                    max_frames=2,
                    extra_lines=1,
                    suppress=[anyio, starlette],
                    word_wrap=False,
                    width=min(console.width, 200),
                )
            else:
                # Print the exception being reported, not whatever happens to be
                # in sys.exc_info() at this point.
                traceback.print_exception(type(e), e, e.__traceback__)
        return JSONResponse(
            status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
        )

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)

    # NOTE(review): wildcard origins combined with allow_credentials lets any
    # site issue credentialed requests to this server; confirm this is intended.
    cors_options = {
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_origins": ["*"],
        "allow_credentials": True,
    }
    app.add_middleware(CORSMiddleware, **cors_options)
| |
|
| |
|
# Socket.IO server used by diffuser_callback; stays None until Api.__init__
# creates the server and assigns it here.
global_sio: Optional[AsyncServer] = None
| |
|
| |
|
def diffuser_callback(
    pipe, step: int, timestep: int, callback_kwargs: Optional[Dict] = None
):
    """Per-step callback that reports diffusion progress over Socket.IO.

    Emits a ``diffusion_progress`` event with payload ``{"step": step}`` through
    the module-level ``global_sio`` server. ``pipe``, ``timestep`` and
    ``callback_kwargs`` are accepted but not used here.

    Returns:
        An empty dict. NOTE(review): presumably this tells the pipeline that
        nothing is overridden; confirm against the caller.
    """
    sio = global_sio
    if sio is None:
        # The Socket.IO server has not been created yet; nothing to notify.
        return {}
    # asyncio.run creates a fresh event loop for the emit coroutine; it raises
    # RuntimeError if this thread already runs a loop, so this assumes it is
    # invoked from a thread without one.
    asyncio.run(sio.emit("diffusion_progress", {"step": step}))
    return {}
| |
|
| |
|
class Api:
    """HTTP + Socket.IO front end for the inpainting model and its plugins.

    Constructing an ``Api`` installs error handling/CORS, builds the file manager,
    plugins and model manager from ``config``, and registers every ``/api/v1``
    route. It then mounts the bundled web UI at ``/`` and creates the Socket.IO
    server used for progress events. :meth:`launch` serves everything with uvicorn.
    """

    def __init__(self, app: FastAPI, config: ApiConfig):
        self.app = app
        self.config = config
        self.router = APIRouter()
        # NOTE(review): this lock is never acquired anywhere in this class, so
        # concurrent requests are not serialized around the model; confirm intent.
        self.queue_lock = threading.Lock()
        api_middleware(self.app)

        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()

        # fmt: off
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
        # Mounted after the API routes: routes are matched in registration order,
        # so the catch-all static mount must come last.
        self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
        # fmt: on

        global global_sio
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        # The ASGIApp wrapper answers Socket.IO traffic itself and forwards every
        # other request to the FastAPI app; launch() serves this wrapper.
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        # NOTE(review): registered after the "/" mount above, so presumably never
        # reached through self.app; confirm whether it is still needed.
        self.app.mount("/ws", self.combined_asgi_app)
        # Publish the server so the module-level diffusion callback can emit.
        global_sio = self.sio

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register ``endpoint`` at ``path`` directly on the FastAPI app."""
        return self.app.add_api_route(path, endpoint, **kwargs)

    def api_save_image(self, file: UploadFile):
        """Write an uploaded file into ``config.output_dir``.

        Only the final component of the client-supplied filename is used, so a
        name such as ``../../x.png`` cannot escape the output directory.

        Raises:
            HTTPException: 400 if no output directory is configured or the
                filename is empty/invalid after sanitising.
        """
        if self.config.output_dir is None:
            raise HTTPException(
                status_code=400, detail="Output directory is not configured"
            )
        # The filename is client-controlled: strip any directory components.
        filename = Path(file.filename or "").name
        if filename in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid filename")
        origin_image_bytes = file.file.read()
        with open(self.config.output_dir / filename, "wb") as fw:
            fw.write(origin_image_bytes)

    def api_current_model(self) -> ModelInfo:
        """Return information about the currently loaded model."""
        return self.model_manager.current_model

    def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
        """Load model ``req.name`` unless it is already active; return its info."""
        if req.name != self.model_manager.name:
            self.model_manager.switch(req.name)
        return self.model_manager.current_model

    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
        """Switch the model of an enabled plugin and record the choice in config.

        Requests naming a plugin that is not enabled are silently ignored.
        """
        if req.plugin_name not in self.plugins:
            return
        self.plugins[req.plugin_name].switch_model(req.model_name)
        # Keep config in sync so api_server_config reports the new selection.
        if req.plugin_name == RemoveBG.name:
            self.config.remove_bg_model = req.model_name
        if req.plugin_name == RealESRGANUpscaler.name:
            self.config.realesrgan_model = req.model_name
        if req.plugin_name == InteractiveSeg.name:
            self.config.interactive_seg_model = req.model_name
        torch_gc()

    def api_server_config(self) -> ServerConfigResponse:
        """Describe enabled plugins, available models and UI feature flags."""
        plugins = [
            PluginInfo(
                name=it.name,
                support_gen_image=it.support_gen_image,
                support_gen_mask=it.support_gen_mask,
            )
            for it in self.plugins.values()
        ]

        return ServerConfigResponse(
            plugins=plugins,
            modelInfos=self.model_manager.scan_models(),
            removeBGModel=self.config.remove_bg_model,
            removeBGModels=RemoveBGModel.values(),
            realesrganModel=self.config.realesrgan_model,
            realesrganModels=RealESRGANModel.values(),
            interactiveSegModel=self.config.interactive_seg_model,
            interactiveSegModels=InteractiveSegModel.values(),
            enableFileManager=self.file_manager is not None,
            enableAutoSaving=self.config.output_dir is not None,
            enableControlnet=self.model_manager.enable_controlnet,
            controlnetMethod=self.model_manager.controlnet_method,
            disableModelSwitch=False,
            isDesktop=False,
            samplers=self.api_samplers(),
        )

    def api_input_image(self) -> FileResponse:
        """Serve ``config.input`` when it is a single file.

        Raises:
            HTTPException: 404 if no input file is configured.
        """
        if self.config.input and self.config.input.is_file():
            return FileResponse(self.config.input)
        raise HTTPException(status_code=404, detail="Input image not found")

    def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
        """Extract prompt / negative prompt from an image's ``parameters`` info.

        The text before ``"Negative prompt: "`` is the prompt; the first line
        after it is the negative prompt (empty if the marker is absent).
        """
        _, _, info = load_img(file.file.read(), return_info=True)
        parts = info.get("parameters", "").split("Negative prompt: ")
        prompt = parts[0].strip()
        negative_prompt = ""
        if len(parts) > 1:
            negative_prompt = parts[1].split("\n")[0].strip()
        return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)

    def api_inpaint(self, req: InpaintRequest):
        """Run the current model on ``req.image`` inside ``req.mask``.

        The model output is converted BGR -> RGB, the input's alpha channel is
        re-attached, and the result is PNG-encoded with the input's ``infos``.
        A ``diffusion_finish`` Socket.IO event is emitted before responding; the
        ``X-Seed`` header echoes ``req.sd_seed``.

        Raises:
            HTTPException: 400 if the image and mask sizes differ.
        """
        image, alpha_channel, infos = decode_base64_to_image(req.image)
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)

        # Binarize the mask to {0, 255}: values above 127 become 255.
        mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        if image.shape[:2] != mask.shape[:2]:
            raise HTTPException(
                400,
                detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
            )

        start = time.time()
        rgb_np_img = self.model_manager(image, mask, req)
        logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
        torch_gc()

        rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
        rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)

        ext = "png"
        res_img_bytes = pil_to_bytes(
            Image.fromarray(rgb_res),
            ext=ext,
            quality=self.config.quality,
            infos=infos,
        )

        asyncio.run(self.sio.emit("diffusion_finish"))

        return Response(
            content=res_img_bytes,
            media_type=f"image/{ext}",
            headers={"X-Seed": str(req.sd_seed)},
        )

    def api_run_plugin_gen_image(self, req: RunPluginRequest):
        """Run an image-producing plugin on ``req.image`` and return a PNG.

        A 4-channel plugin result is returned as-is; otherwise it is converted
        BGR -> RGB and the input's alpha channel is re-attached.

        Raises:
            HTTPException: 422 if the plugin is unknown or cannot output images.
        """
        ext = "png"
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_image:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output image"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
        torch_gc()

        if bgr_or_rgba_np_img.shape[2] == 4:
            rgba_np_img = bgr_or_rgba_np_img
        else:
            rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
            rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)

        return Response(
            content=pil_to_bytes(
                Image.fromarray(rgba_np_img),
                ext=ext,
                quality=self.config.quality,
                infos=infos,
            ),
            media_type=f"image/{ext}",
        )

    def api_run_plugin_gen_mask(self, req: RunPluginRequest):
        """Run a mask-producing plugin on ``req.image`` and return a PNG mask.

        Raises:
            HTTPException: 422 if the plugin is unknown or cannot output masks.
        """
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_mask:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output mask"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
        torch_gc()
        res_mask = gen_frontend_mask(bgr_or_gray_mask)
        return Response(
            content=numpy_to_bytes(res_mask, "png"),
            media_type="image/png",
        )

    def api_samplers(self) -> List[str]:
        """Return the value of every ``SDSampler`` member (aliases included)."""
        return [member.value for member in SDSampler.__members__.values()]

    def api_adjust_mask(self, req: AdjustMaskRequest):
        """Apply ``adjust_mask`` with ``req.kernel_size``/``req.operate``; return PNG."""
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)
        mask = adjust_mask(mask, req.kernel_size, req.operate)
        return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")

    def launch(self):
        """Serve the combined Socket.IO + FastAPI app with uvicorn (blocking)."""
        self.app.include_router(self.router)
        uvicorn.run(
            self.combined_asgi_app,
            host=self.config.host,
            port=self.config.port,
            timeout_keep_alive=999999999,
        )

    def _build_file_manager(self) -> Optional[FileManager]:
        """Create a FileManager when ``config.input`` is a directory, else None."""
        if self.config.input and self.config.input.is_dir():
            logger.info(
                f"Input is directory, initialize file manager {self.config.input}"
            )

            return FileManager(
                app=self.app,
                input_dir=self.config.input,
                output_dir=self.config.output_dir,
            )
        return None

    def _build_plugins(self) -> Dict[str, BasePlugin]:
        """Build the enabled plugins from config (positional order matters)."""
        return build_plugins(
            self.config.enable_interactive_seg,
            self.config.interactive_seg_model,
            self.config.interactive_seg_device,
            self.config.enable_remove_bg,
            self.config.remove_bg_model,
            self.config.enable_anime_seg,
            self.config.enable_realesrgan,
            self.config.realesrgan_device,
            self.config.realesrgan_model,
            self.config.enable_gfpgan,
            self.config.gfpgan_device,
            self.config.enable_restoreformer,
            self.config.restoreformer_device,
            self.config.no_half,
        )

    def _build_model_manager(self):
        """Create the ModelManager, wiring in the Socket.IO progress callback."""
        return ModelManager(
            name=self.config.model,
            device=torch.device(self.config.device),
            no_half=self.config.no_half,
            low_mem=self.config.low_mem,
            disable_nsfw=self.config.disable_nsfw_checker,
            sd_cpu_textencoder=self.config.cpu_textencoder,
            local_files_only=self.config.local_files_only,
            cpu_offload=self.config.cpu_offload,
            callback=diffuser_callback,
        )
| |
|