File size: 16,368 Bytes
8f80642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15a969b
8f80642
 
 
 
 
 
 
15a969b
 
8f80642
 
 
 
 
 
 
 
15a969b
8f80642
 
 
 
 
 
 
 
 
15a969b
 
8f80642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15a969b
8f80642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# -------------------------------------------------------------------
# This source file is available under the terms of the
# Pimcore Open Core License (POCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
#  @copyright  Copyright (c) Pimcore GmbH (https://www.pimcore.com)
#  @license    Pimcore Open Core License (POCL)
# -------------------------------------------------------------------

import torch
import base64
import io
import logging
from PIL import Image
from pydantic import BaseModel
from fastapi import Request, HTTPException
import json
from typing import Optional, Union, Dict, Any
from transformers import AutoProcessor, AutoModel


class EmbeddingRequest(BaseModel):
    """Request body accepted by the embedding endpoints.

    ``inputs`` is the payload to embed. The text service embeds it as-is.
    The image service treats it as a base64-encoded image, optionally
    prefixed with a ``data:image...,`` data-URL header.
    """
    inputs: str
    # Optional free-form options; not read by the services visible in this module.
    parameters: Optional[dict] = None


class BaseEmbeddingTaskService:
    """Base class for embedding services with common functionality.

    Provides request-body parsing, device selection, and cached loading of
    Hugging Face processors/models. It also looks up embedding vector sizes
    from a model config and extracts an embedding tensor from heterogeneous
    model outputs. Caches are plain per-instance dicts and are never evicted.
    """

    # Config attributes consulted, in priority order, by
    # get_embedding_vector_size. Dotted entries are looked up on a sub-config.
    _VECTOR_SIZE_ATTRIBUTES = (
        "hidden_size",
        "projection_dim",
        "d_model",
        "text_config.hidden_size",
        "vision_config.hidden_size",
    )

    def __init__(self, logger: logging.Logger):
        self._logger = logger
        # cache key (model name + optional suffix) -> loaded model
        self._model_cache = {}
        # model name -> loaded processor
        self._processor_cache = {}

    async def get_embedding_request(self, request: Request) -> EmbeddingRequest:
        """Parse the request body into an ``EmbeddingRequest``.

        The body must be a JSON object matching ``EmbeddingRequest``. It is
        accepted under either ``application/json`` or
        ``application/x-www-form-urlencoded``. In both cases the raw body is
        parsed as JSON.

        Raises:
            HTTPException(400): unsupported content type, a body that is not
                valid JSON, or JSON that does not match ``EmbeddingRequest``.
        """
        # Media types are case-insensitive, so compare on a lower-cased value.
        content_type = request.headers.get("content-type", "").lower()
        if not content_type.startswith(
            ("application/json", "application/x-www-form-urlencoded")
        ):
            raise HTTPException(status_code=400, detail="Unsupported content type")

        raw = await request.body()
        try:
            # json.loads accepts bytes directly (UTF-8/16/32 auto-detected),
            # so no separate decode-and-retry pass is needed.
            data = json.loads(raw)
            return EmbeddingRequest(**data)
        except (ValueError, TypeError) as e:
            # ValueError covers malformed JSON, undecodable bytes and pydantic
            # validation errors. TypeError covers JSON that is not an object,
            # since ** requires a mapping.
            raise HTTPException(status_code=400, detail="Invalid request body") from e

    def _get_device(self) -> torch.device:
        """Return the CUDA device if available, otherwise the CPU device."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._logger.info(f"Using device: {device}")
        return device

    def _load_processor(self, model_name: str):
        """Load (once per model name) and return the processor via AutoProcessor.

        Raises:
            HTTPException(404): if the processor cannot be loaded.
        """
        if model_name not in self._processor_cache:
            try:
                self._processor_cache[model_name] = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
                self._logger.info(f"Loaded processor for model: {model_name}")
            except Exception as e:
                self._logger.error(f"Failed to load processor for model '{model_name}': {str(e)}")
                raise HTTPException(
                    status_code=404,
                    detail=f"Processor for model '{model_name}' could not be loaded: {str(e)}"
                ) from e
        else:
            self._logger.info(f"Using cached processor for model: {model_name}")
        return self._processor_cache[model_name]

    def _load_model(self, model_name: str, cache_suffix: str = ""):
        """Load (once per cache key) a model via AutoModel and move it to the device.

        The cache key is ``model_name + cache_suffix``. The same model name
        requested with different suffixes is therefore loaded as a separate
        copy per suffix.

        Raises:
            HTTPException(404): if the model cannot be loaded.
        """
        cache_key = f"{model_name}{cache_suffix}"
        if cache_key not in self._model_cache:
            try:
                device = self._get_device()
                model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
                model.to(device)
                self._model_cache[cache_key] = model
                self._logger.info(f"Loaded model: {model_name} on {device}")
            except Exception as e:
                self._logger.error(f"Failed to load model '{model_name}': {str(e)}")
                raise HTTPException(
                    status_code=404,
                    detail=f"Model '{model_name}' could not be loaded: {str(e)}"
                ) from e
        else:
            self._logger.info(f"Using cached model: {model_name} (cache key: {cache_key})")
        return self._model_cache[cache_key]

    @classmethod
    def _vector_size_from_config(cls, config) -> tuple:
        """Return ``(vector_size, attribute_path)`` for the first known attribute on ``config``.

        Raises:
            AttributeError: if none of ``_VECTOR_SIZE_ATTRIBUTES`` is present.
        """
        for attribute_path in cls._VECTOR_SIZE_ATTRIBUTES:
            value = config
            try:
                for part in attribute_path.split("."):
                    value = getattr(value, part)
            except AttributeError:
                continue
            return value, attribute_path
        raise AttributeError("Could not determine vector size from model configuration")

    async def get_embedding_vector_size(self, model_name: str) -> dict:
        """Report the embedding vector size declared by a model's configuration.

        Returns:
            ``{"model_name", "vector_size", "config_attribute_used"}``.

        Raises:
            HTTPException(404): the model cannot be loaded, or no known size
                attribute exists in its config.

        NOTE(review): the size is read from config only. For models embedded
        via ``get_*_features``, the produced vector may use ``projection_dim``
        even when ``hidden_size`` is reported here. Confirm per model.
        """
        # _load_model already raises a descriptive HTTPException(404). It stays
        # outside the try below so that error is not wrapped a second time.
        model = self._load_model(model_name)

        try:
            vector_size, used_attribute = self._vector_size_from_config(model.config)
        except Exception as e:
            self._logger.error(f"Failed to get vector size for model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Could not determine vector size for model '{model_name}': {str(e)}"
            ) from e

        self._logger.info(f"Model {model_name} has embedding vector size: {vector_size}")
        return {
            "model_name": model_name,
            "vector_size": vector_size,
            "config_attribute_used": used_attribute
        }

    def _extract_embeddings(self, model_output, model_name: str) -> torch.Tensor:
        """Pick an embedding tensor out of a model output, trying several layouts in order.

        Order: ``pooler_output``; unweighted mean of ``last_hidden_state``
        over dim 1 (padding positions, if any, are included); ``image_embeds``;
        ``text_embeds``; the output itself if it is a tensor; the first
        element of a non-empty tuple.

        Raises:
            HTTPException(500): if none of the above applies.
        """
        # 1. pooler_output
        if hasattr(model_output, 'pooler_output') and model_output.pooler_output is not None:
            self._logger.debug(f"Using pooler_output for {model_name}")
            return model_output.pooler_output

        # 2. last_hidden_state, mean-pooled over dim 1 (no attention mask applied)
        if hasattr(model_output, 'last_hidden_state') and model_output.last_hidden_state is not None:
            self._logger.debug(f"Using pooled last_hidden_state for {model_name}")
            return model_output.last_hidden_state.mean(dim=1)

        # 3. image_embeds (CLIP-style outputs)
        if hasattr(model_output, 'image_embeds') and model_output.image_embeds is not None:
            self._logger.debug(f"Using image_embeds for {model_name}")
            return model_output.image_embeds

        # 4. text_embeds (CLIP-style outputs)
        if hasattr(model_output, 'text_embeds') and model_output.text_embeds is not None:
            self._logger.debug(f"Using text_embeds for {model_name}")
            return model_output.text_embeds

        # 5. The output is already a tensor
        if isinstance(model_output, torch.Tensor):
            self._logger.debug(f"Using direct tensor output for {model_name}")
            return model_output

        # 6. Last resort: first element of a tuple output
        if isinstance(model_output, tuple) and len(model_output) > 0:
            self._logger.debug(f"Using first element of tuple output for {model_name}")
            return model_output[0]

        raise HTTPException(
            status_code=500,
            detail=f"Could not extract embeddings from model output for {model_name}. "
                   f"Available attributes: {dir(model_output) if hasattr(model_output, '__dict__') else 'Unknown'}"
        )


class ImageEmbeddingTaskService(BaseEmbeddingTaskService):
    """Service for generating image embeddings.

    An image arrives either base64-encoded in an ``EmbeddingRequest``
    (optionally as a ``data:image...`` URL) or as an uploaded file. Each call
    produces one embedding, returned as ``{"embeddings": [float, ...]}``.
    """

    @staticmethod
    def _open_rgb_image(image_bytes: bytes) -> Image.Image:
        """Decode encoded image bytes into a fully loaded RGB PIL image.

        Propagates whatever PIL raises for unreadable data. Callers translate
        that into an HTTP 400.
        """
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy. Decode now so corrupt or truncated data fails
        # here (as a client error) instead of later inside the processor.
        image.load()
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image

    def _decode_base64_image(self, base64_string: str) -> Image.Image:
        """Decode a base64 string (optionally a ``data:image`` URL) into an RGB PIL image.

        Raises:
            HTTPException(400): if the string cannot be decoded into an image.
        """
        try:
            # Drop a "data:image/...;base64," prefix if present.
            if base64_string.startswith('data:image'):
                base64_string = base64_string.split(',')[1]
            return self._open_rgb_image(base64.b64decode(base64_string))
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}") from e

    def _generate_image_embeddings(self, image: Image.Image, model, processor, model_name: str) -> list:
        """Run ``image`` through ``model`` and return its embedding as a list of floats.

        Prefers ``model.get_image_features`` (CLIP-style dual encoders), then
        ``model.vision_model``, then the full model. The latter two go through
        ``_extract_embeddings``. Only the first row of the result is returned.
        """
        device = self._get_device()

        inputs = processor(images=image, return_tensors="pt", padding=True)

        # Move tensor inputs to the selected device; pass anything else through.
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }

        with torch.no_grad():
            if hasattr(model, 'get_image_features'):
                self._logger.debug(f"Using get_image_features for {model_name}")
                embeddings = model.get_image_features(pixel_values=inputs.get('pixel_values'))
            elif hasattr(model, 'vision_model'):
                self._logger.debug(f"Using vision_model for {model_name}")
                vision_outputs = model.vision_model(**inputs)
                embeddings = self._extract_embeddings(vision_outputs, model_name)
            else:
                self._logger.debug(f"Using full model for {model_name}")
                outputs = model(**inputs)
                embeddings = self._extract_embeddings(outputs, model_name)

        self._logger.info(f"Image embedding shape: {embeddings.shape}")

        # Tensor.tolist() works for every dtype, including bfloat16, which
        # .numpy() rejects. It yields plain Python floats.
        return embeddings[0].cpu().tolist()

    async def generate_embedding(self, request: Request, model_name: str):
        """Generate an image embedding from a base64 image in the request body.

        Raises:
            HTTPException(400): invalid body or undecodable image.
            HTTPException(404): processor/model cannot be loaded.
            HTTPException(500): embedding generation failed.
        """
        embedding_request: EmbeddingRequest = await self.get_embedding_request(request)

        self._logger.info(f"Generating image embedding for model: {model_name}")

        processor = self._load_processor(model_name)
        model = self._load_model(model_name, "_image")

        image = self._decode_base64_image(embedding_request.inputs)

        try:
            embeddings = self._generate_image_embeddings(image, model, processor, model_name)
        except HTTPException:
            # Already carries a meaningful status/detail; don't wrap it again.
            raise
        except Exception as e:
            self._logger.error(f"Embedding generation failed for model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Embedding generation failed: {str(e)}"
            ) from e

        self._logger.info("Image embedding generation completed")
        return {"embeddings": embeddings}

    async def generate_embedding_from_upload(self, uploaded_file, model_name: str):
        """Generate an image embedding from an uploaded file.

        ``uploaded_file`` presumably is a FastAPI/Starlette ``UploadFile``. It
        must expose ``content_type`` and an async ``read()``.

        Raises:
            HTTPException(400): non-image content type, or unreadable image data.
            HTTPException(404): processor/model cannot be loaded.
            HTTPException(500): embedding generation failed.
        """
        self._logger.info(f"Generating image embedding from uploaded file for model: {model_name}")

        # content_type may be None when the part carries no Content-Type header.
        if not (uploaded_file.content_type or "").startswith('image/'):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid file type: {uploaded_file.content_type}. Only image files are supported."
            )

        # Validate the image before the (expensive) model load, and report bad
        # data as a client error, consistent with the base64 path.
        try:
            file_content = await uploaded_file.read()
            image = self._open_rgb_image(file_content)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}") from e

        # Loading raises its own HTTPException(404); keep it out of the 500 wrapper.
        processor = self._load_processor(model_name)
        model = self._load_model(model_name, "_image")

        try:
            embeddings = self._generate_image_embeddings(image, model, processor, model_name)
        except HTTPException:
            raise
        except Exception as e:
            self._logger.error(f"Embedding generation from upload failed for model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Embedding generation from upload failed: {str(e)}"
            ) from e

        self._logger.info("Image embedding generation from upload completed")
        return {"embeddings": embeddings}


class TextEmbeddingTaskService(BaseEmbeddingTaskService):
    """Service for generating text embeddings.

    Embeds the ``inputs`` string of an ``EmbeddingRequest`` and returns
    ``{"embeddings": [float, ...]}``.
    """

    def _generate_text_embeddings(self, text: str, model, processor, model_name: str) -> list:
        """Run ``text`` through ``model`` and return its embedding as a list of floats.

        Prefers ``model.get_text_features`` (CLIP-style dual encoders), then
        ``model.text_model``, then the full model. The latter two go through
        ``_extract_embeddings``. Only the first row of the result is returned.
        """
        device = self._get_device()

        # Tokenise as a one-element batch. truncation=True asks the processor
        # to truncate over-long input.
        inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)

        # Move tensor inputs to the selected device; pass anything else through.
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }

        with torch.no_grad():
            if hasattr(model, 'get_text_features'):
                self._logger.debug(f"Using get_text_features for {model_name}")
                embeddings = model.get_text_features(
                    input_ids=inputs.get('input_ids'),
                    attention_mask=inputs.get('attention_mask')
                )
            elif hasattr(model, 'text_model'):
                self._logger.debug(f"Using text_model for {model_name}")
                text_outputs = model.text_model(**inputs)
                embeddings = self._extract_embeddings(text_outputs, model_name)
            else:
                self._logger.debug(f"Using full model for {model_name}")
                outputs = model(**inputs)
                embeddings = self._extract_embeddings(outputs, model_name)

        self._logger.info(f"Text embedding shape: {embeddings.shape}")

        # Tensor.tolist() works for every dtype, including bfloat16, which
        # .numpy() rejects. It yields plain Python floats.
        return embeddings[0].cpu().tolist()

    async def generate_embedding(self, request: Request, model_name: str):
        """Generate a text embedding for the ``inputs`` string in the request body.

        Raises:
            HTTPException(400): invalid request body.
            HTTPException(404): processor/model cannot be loaded.
            HTTPException(500): embedding generation failed.
        """
        embedding_request: EmbeddingRequest = await self.get_embedding_request(request)

        self._logger.info(f"Generating text embedding for: {embedding_request.inputs[:500]}...")

        processor = self._load_processor(model_name)
        model = self._load_model(model_name, "_text")

        try:
            embeddings = self._generate_text_embeddings(embedding_request.inputs, model, processor, model_name)
        except HTTPException:
            # Already carries a meaningful status/detail; don't wrap it again.
            raise
        except Exception as e:
            self._logger.error(f"Embedding generation failed for model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Embedding generation failed: {str(e)}"
            ) from e

        self._logger.info("Text embedding generation completed")
        return {"embeddings": embeddings}