| import numpy as np | |
| from typing import List | |
| from .common import CommonOCR, OfflineOCR | |
| from .model_32px import Model32pxOCR | |
| from .model_48px import Model48pxOCR | |
| from .model_48px_ctc import Model48pxCTCOCR | |
| from .model_manga_ocr import ModelMangaOCR | |
| from ..utils import Quadrilateral | |
# Registry of available OCR backends, keyed by the name callers pass to
# ``get_ocr`` / ``prepare`` / ``dispatch``. Insertion order is preserved and
# is the order in which the keys are listed in "unknown OCR" error messages.
OCRS = dict((
    ('32px', Model32pxOCR),
    ('48px', Model48pxOCR),
    ('48px_ctc', Model48pxCTCOCR),
    ('mocr', ModelMangaOCR),
))

# One lazily-created OCR instance per registry key; populated by ``get_ocr``.
ocr_cache = dict()
def get_ocr(key: str, *args, **kwargs) -> CommonOCR:
    """Return the shared OCR instance registered under *key*, creating it on first use.

    Instances are cached per key in ``ocr_cache``. ``args`` and ``kwargs`` are
    forwarded to the OCR class constructor only when the instance is first
    created; later calls with the same key return the cached instance and
    ignore any constructor arguments they pass.

    Args:
        key: Name of a registered OCR backend (a key of ``OCRS``).
        *args: Positional arguments for the OCR class constructor.
        **kwargs: Keyword arguments for the OCR class constructor.

    Returns:
        The cached OCR instance for *key*.

    Raises:
        ValueError: If *key* is not a registered OCR name.
    """
    if key not in OCRS:
        # Build the message with a single f-string. Combining an f-string with
        # ``%`` formatting would treat any '%' inside the user-supplied key as a
        # format directive and blow up (or mangle the message).
        choices = ','.join(OCRS)
        raise ValueError(f'Could not find OCR for: "{key}". Choose from the following: {choices}')
    # Compare against None rather than testing truthiness, so an OCR object
    # that happens to evaluate falsy (e.g. defines __len__) is not rebuilt on
    # every call.
    if ocr_cache.get(key) is None:
        ocr_class = OCRS[key]
        ocr_cache[key] = ocr_class(*args, **kwargs)
    return ocr_cache[key]
async def prepare(ocr_key: str, device: str = 'cpu'):
    """Fetch and load the model behind the OCR backend named *ocr_key*.

    The backend instance is obtained via ``get_ocr`` (created on first use).
    For ``OfflineOCR`` backends this awaits ``download()`` and then
    ``load(device)``. Any other backend is left untouched.

    Args:
        ocr_key: Name of a registered OCR backend.
        device: Device identifier passed through to ``load``.
    """
    model = get_ocr(ocr_key)
    if not isinstance(model, OfflineOCR):
        # Nothing to download or load for non-offline backends.
        return
    await model.download()
    await model.load(device)
async def dispatch(ocr_key: str, image: np.ndarray, regions: List[Quadrilateral], args = None, device: str = 'cpu', verbose: bool = False) -> List[Quadrilateral]:
    """Run text recognition over *regions* of *image* using the OCR named *ocr_key*.

    For ``OfflineOCR`` backends, ``load(device)`` is awaited before
    recognition. The result of the backend's ``recognize`` coroutine is
    returned unchanged.

    Args:
        ocr_key: Name of a registered OCR backend.
        image: Source image handed to the backend.
        regions: Regions of *image* to recognize.
        args: Optional backend-specific options; a falsy value becomes ``{}``.
        device: Device identifier passed to ``load`` for offline backends.
        verbose: Forwarded to ``recognize``.

    Returns:
        Whatever ``recognize`` returns (annotated as a list of ``Quadrilateral``).
    """
    engine = get_ocr(ocr_key)
    if isinstance(engine, OfflineOCR):
        # NOTE(review): load() is awaited on every dispatch; presumably it is
        # cheap once the model is already loaded — confirm in OfflineOCR.
        await engine.load(device)
    options = {} if not args else args
    return await engine.recognize(image, regions, options, verbose)