| """ |
| Image processor for Gemma3Tiled that tiles images into grids. |
| |
| Instead of resizing images to a fixed size or using pan-and-scan crops, |
| this processor tiles the image into a grid of 896x896 patches that |
| preserves the spatial layout. |
| """ |
|
|
| import math |
| from typing import Optional, Union |
|
|
| import numpy as np |
| from PIL import Image |
|
|
| from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
| from transformers.image_utils import ( |
| ImageInput, |
| make_flat_list_of_images, |
| valid_images, |
| infer_channel_dimension_format, |
| to_numpy_array, |
| ChannelDimension, |
| ) |
| from transformers.utils import TensorType |
|
|
|
|
def calculate_tile_grid(
    image_height: int,
    image_width: int,
    tile_size: int,
    max_tiles_h: int,
    max_tiles_w: int,
    min_tiles: int = 1,
) -> tuple[int, int]:
    """
    Pick the tile grid that best fits an image.

    Every grid up to max_tiles_h x max_tiles_w is scored by how much of the
    original resolution it preserves, with wasted canvas area as a small
    penalty tiebreaker. Upscaling earns no credit: the effective resolution
    is capped at the original pixel count, so a larger grid only wins when
    it keeps more genuine detail.

    Example: For a 1344x912 image with 896x896 tiles:
      | Canvas      | Scale | Effective   | Wasted    |
      |-------------|-------|-------------|-----------|
      | 1×1 (896²)  | 0.667 | 544,768     | 258,048   |
      | 1×2         | 0.982 | 1,182,720   | 422,912   |
      | 2×1         | 0.667 | 544,768     | 1,060,864 |
      | 2×2         | 1.333 | 1,225,728 ✓ | 1,985,536 |

    Winner: 2×2 (highest effective resolution = 100% of original pixels)

    Args:
        image_height: Original image height
        image_width: Original image width
        tile_size: Size of each tile (896)
        max_tiles_h: Maximum tiles in height
        max_tiles_w: Maximum tiles in width
        min_tiles: Minimum total tiles

    Returns:
        (grid_h, grid_w): Number of tiles in height and width
    """
    original_pixels = image_height * image_width

    def grid_score(grid: tuple[int, int]) -> float:
        rows, cols = grid
        canvas_h = rows * tile_size
        canvas_w = cols * tile_size
        # Uniform scale that fits the image inside the canvas.
        scale = min(canvas_w / image_width, canvas_h / image_height)
        # Pixels of original detail retained, capped so upscaling is not credited.
        effective = min(original_pixels * scale * scale, original_pixels)
        # Canvas area not covered by retained detail.
        waste = (canvas_h * canvas_w) - effective
        return effective - 0.001 * waste

    candidates = [
        (rows, cols)
        for rows in range(1, max_tiles_h + 1)
        for cols in range(1, max_tiles_w + 1)
        if rows * cols >= min_tiles
    ]

    # No grid satisfies min_tiles: fall back to a single tile, matching the
    # default of the original scoring loop.
    if not candidates:
        return (1, 1)

    # max() keeps the first-seen grid on ties, i.e. row-major order, the
    # same preference as a strict ">" comparison in a nested loop.
    return max(candidates, key=grid_score)
|
|
|
|
def tile_image(
    image: np.ndarray,
    tile_size: int,
    grid_h: int,
    grid_w: int,
    resample: Image.Resampling = Image.Resampling.BICUBIC,
) -> np.ndarray:
    """
    Tile an image into a grid of fixed-size patches.

    The image is resized to (grid_h * tile_size, grid_w * tile_size) and then
    split into grid_h x grid_w non-overlapping square tiles.

    Args:
        image: Input image as numpy array (H, W, C) or (C, H, W). Float
            inputs with max <= 1.0 are assumed to lie in [0, 1] and are
            rescaled to [0, 255]; all inputs are clipped to the uint8 range
            before conversion.
        tile_size: Size of each tile
        grid_h: Number of tiles in height
        grid_w: Number of tiles in width
        resample: PIL resampling method

    Returns:
        Tiled image array of shape (grid_h * grid_w, C, tile_size, tile_size)
    """
    # Convert channels-first (C, H, W) to channels-last (H, W, C) for PIL.
    # NOTE(review): a channels-last image whose height happens to be 1, 3,
    # or 4 would be misdetected here -- heuristic matches the rest of the file.
    if image.ndim == 3:
        if image.shape[0] in [1, 3, 4]:
            image = np.transpose(image, (1, 2, 0))

    # Quantize to uint8 for PIL. Clip first so out-of-range floats cannot
    # wrap around in the cast (e.g. 256.0 -> 0, -1.0 -> 255).
    if np.issubdtype(image.dtype, np.floating) and image.max() <= 1.0:
        image = np.clip(image * 255, 0, 255).astype(np.uint8)
    else:
        image = np.clip(image, 0, 255).astype(np.uint8)

    # PIL cannot build an image from an (H, W, 1) array; drop the trailing
    # channel axis for single-channel inputs and restore it after resizing.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    pil_image = Image.fromarray(image)

    # Resize so the image divides exactly into the requested grid.
    target_h = grid_h * tile_size
    target_w = grid_w * tile_size
    pil_image = pil_image.resize((target_w, target_h), resample=resample)

    image = np.array(pil_image)
    # Grayscale comes back 2-D from PIL; add an explicit channel axis so the
    # per-tile transpose below is uniform.
    if image.ndim == 2:
        image = image[:, :, None]

    # Cut the canvas into row-major tiles, each converted to (C, H, W).
    tiles = []
    for i in range(grid_h):
        for j in range(grid_w):
            y_start = i * tile_size
            x_start = j * tile_size
            tile = image[y_start:y_start + tile_size, x_start:x_start + tile_size]
            tile = np.transpose(tile, (2, 0, 1))
            tiles.append(tile)

    return np.stack(tiles, axis=0)
|
|
|
|
class Gemma3TiledImageProcessor(BaseImageProcessor):
    """
    Image processor for Gemma3Tiled that tiles images into grids.

    Instead of resizing each image to one fixed square or using pan-and-scan
    crops, the image is resized to the best-fitting multiple of ``tile_size``
    and cut into a grid of square tiles, preserving its spatial layout.

    For every image this processor:
      1. Picks the optimal tile grid (see ``calculate_tile_grid``).
      2. Resizes and tiles the image (see ``tile_image``).
      3. Optionally rescales and normalizes the resulting tiles.
    """

    model_input_names = ["pixel_values", "tile_grid_shape", "num_crops"]
    _auto_class = "AutoImageProcessor"

    def __init__(
        self,
        tile_size: int = 896,
        max_tiles_h: int = 4,
        max_tiles_w: int = 4,
        min_tiles: int = 1,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        do_convert_rgb: bool = True,
        resample: Image.Resampling = Image.Resampling.BICUBIC,
        **kwargs,
    ):
        """
        Args:
            tile_size: Side length of each square tile.
            max_tiles_h: Maximum number of tiles along the height.
            max_tiles_w: Maximum number of tiles along the width.
            min_tiles: Minimum total number of tiles per image.
            do_rescale: Whether to multiply pixel values by ``rescale_factor``.
            rescale_factor: Factor applied when rescaling (default 1/255).
            do_normalize: Whether to normalize with ``image_mean``/``image_std``.
            image_mean: Per-channel mean; defaults to [0.5, 0.5, 0.5].
            image_std: Per-channel std; defaults to [0.5, 0.5, 0.5].
            do_convert_rgb: Whether to drop the alpha channel of RGBA images.
            resample: PIL resampling filter used when resizing.
        """
        super().__init__(**kwargs)

        self.tile_size = tile_size
        self.max_tiles_h = max_tiles_h
        self.max_tiles_w = max_tiles_w
        self.min_tiles = min_tiles
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.do_convert_rgb = do_convert_rgb
        self.resample = resample

    def preprocess(
        self,
        images: ImageInput,
        tile_size: Optional[int] = None,
        max_tiles_h: Optional[int] = None,
        max_tiles_w: Optional[int] = None,
        min_tiles: Optional[int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess images by tiling them into grids.

        Args:
            images: Single image or batch of images. Any per-call argument
                left as None falls back to the value set at construction.

        Returns:
            BatchFeature with:
                - pixel_values: all tiles from all images concatenated along
                  axis 0, shape (sum of per-image tile counts, C, tile_size,
                  tile_size)
                - tile_grid_shape: list of (grid_h, grid_w) per image; each
                  image contributed grid_h * grid_w consecutive tiles to
                  pixel_values
                - num_crops: list of zeros, one per image (kept for interface
                  compatibility with pan-and-scan style processors)

        Raises:
            ValueError: If the images are invalid or not 3-dimensional arrays.
        """
        # Per-call overrides fall back to the instance defaults.
        tile_size = tile_size if tile_size is not None else self.tile_size
        max_tiles_h = max_tiles_h if max_tiles_h is not None else self.max_tiles_h
        max_tiles_w = max_tiles_w if max_tiles_w is not None else self.max_tiles_w
        min_tiles = min_tiles if min_tiles is not None else self.min_tiles
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_flat_list_of_images(images)

        if not valid_images(images):
            raise ValueError("Invalid image input")

        all_pixel_values = []
        all_grid_shapes = []

        for image in images:
            image = to_numpy_array(image)

            # Drop the alpha channel of RGBA inputs. Handle both
            # channels-last (H, W, 4) and channels-first (4, H, W) layouts;
            # the original check only covered channels-last, letting
            # channels-first RGBA reach normalization with 4 channels.
            if do_convert_rgb and image.ndim == 3:
                if image.shape[-1] == 4:
                    image = image[..., :3]
                elif image.shape[0] == 4:
                    image = image[:3, ...]

            # Infer the spatial dims; a leading dim of 1/3/4 is treated as
            # the channel axis (same heuristic as tile_image).
            if image.ndim == 3:
                if image.shape[0] in [1, 3, 4]:
                    h, w = image.shape[1], image.shape[2]
                else:
                    h, w = image.shape[0], image.shape[1]
            else:
                raise ValueError(f"Expected 3D image, got shape {image.shape}")

            grid_h, grid_w = calculate_tile_grid(
                h, w, tile_size, max_tiles_h, max_tiles_w, min_tiles
            )

            tiles = tile_image(
                image, tile_size, grid_h, grid_w, resample=self.resample
            )

            if do_rescale:
                tiles = tiles.astype(np.float32) * rescale_factor

            if do_normalize:
                # Broadcast over (num_tiles, C, H, W). Use -1 rather than a
                # hard-coded 3 so non-RGB channel counts still broadcast.
                mean = np.array(image_mean, dtype=np.float32).reshape(1, -1, 1, 1)
                std = np.array(image_std, dtype=np.float32).reshape(1, -1, 1, 1)
                tiles = (tiles - mean) / std

            all_pixel_values.append(tiles)
            all_grid_shapes.append((grid_h, grid_w))

        # Tiling replaces pan-and-scan cropping, so the crop count is always
        # zero; the field is kept so downstream code expecting it still works.
        num_crops = [0] * len(all_pixel_values)

        # Tiles of all images are packed into one array; tile_grid_shape
        # records how many consecutive tiles belong to each image.
        if len(all_pixel_values) > 0:
            concatenated_pixels = np.concatenate(all_pixel_values, axis=0)
        else:
            concatenated_pixels = np.array([])

        data = {
            "pixel_values": concatenated_pixels,
            "tile_grid_shape": all_grid_shapes,
            "num_crops": num_crops,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)
|
|
|
|
| __all__ = ["Gemma3TiledImageProcessor", "calculate_tile_grid", "tile_image"] |
|
|