siddharth-magesh committed on
Commit
2aeeae9
·
verified ·
1 Parent(s): 72b6ebb

Update README.md

Files changed (1)
  1. README.md +420 -3
README.md CHANGED
@@ -1,3 +1,420 @@
1
- ---
2
- license: mit
3
- ---
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ tags:
6
+ - clip
7
+ - vision-language
8
+ - contrastive-learning
9
+ - image-text-matching
10
+ - pytorch
11
+ - vision-transformer
12
+ - zero-shot
13
+ - multimodal
14
+ - feature-extraction
15
+ library_name: pytorch
16
+ datasets:
17
+ - flickr30k
18
+ metrics:
19
+ - loss
20
+ pipeline_tag: feature-extraction
21
+ ---
22
+
23
+ # CLIP-Flickr30k: Contrastive Language-Image Pretraining Model
24
+
25
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
26
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org/)
27
+ [![Python](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/)
28
+
29
+ This repository contains PyTorch model weights for a CLIP (Contrastive Language-Image Pretraining) implementation trained from scratch on the Flickr30k dataset.
30
+
31
+ ## Model Overview
32
+
33
+ This is a **custom PyTorch implementation** of CLIP, not compatible with Hugging Face Transformers. The model learns to align images and text in a shared embedding space using contrastive learning.
34
+
35
+ ### Architecture
36
+
37
+ - **Vision Encoder**: Vision Transformer (ViT)
+   - Embedding dimension: 768
+   - Depth: 12 layers
+   - Attention heads: 12
+   - Patch size: 16×16
+   - Input size: 224×224
+
+ - **Text Encoder**: Transformer
+   - Embedding dimension: 512
+   - Depth: 8 layers
+   - Attention heads: 8
+   - Max sequence length: 77 tokens
+   - Vocabulary size: 49,408
+
+ - **Output**: 512-dimensional embeddings (both image and text)
52
+
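As a quick check on the vision encoder settings above, the number of patch tokens follows directly from the input and patch sizes. This is a minimal sketch of that arithmetic; whether the implementation prepends a class token is an assumption taken from the standard ViT design, not something stated in this README.

```python
# Patch-grid arithmetic for the ViT settings listed above.
img_size = 224
patch_size = 16

patches_per_side = img_size // patch_size   # patches along one image edge
num_patches = patches_per_side ** 2         # patch tokens per image
patch_dim = 3 * patch_size * patch_size     # raw values in one RGB patch

# Assumption: one extra class token, as in the standard ViT design.
seq_len_with_cls = num_patches + 1

print(num_patches, patch_dim, seq_len_with_cls)
```

With these settings a flattened RGB patch has the same size as the 768-dimensional vision embedding.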
53
+ ### Training Details
54
+
55
+ - **Dataset**: Flickr30k subset (1,000 image-caption pairs from 200 unique images)
56
+ - **Epochs**: 50
57
+ - **Batch Size**: 64
58
+ - **Optimizer**: Adam (lr=1e-4, weight_decay=1e-4)
59
+ - **Scheduler**: CosineAnnealingLR
60
+ - **Temperature**: 0.07
61
+ - **Device**: CUDA (GPU)
62
+ - **Training Time**: 8.12 hours
63
+
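The learning-rate schedule named above can be illustrated with the closed-form cosine annealing formula. This is a sketch, not the training script: the values `t_max=50` (one scheduler step per epoch) and `eta_min=0.0` are assumptions, since the README gives only the scheduler name and the base learning rate.

```python
import math

def cosine_annealing_lr(epoch, base_lr=1e-4, t_max=50, eta_min=0.0):
    """Closed-form cosine annealing: decays base_lr to eta_min over t_max epochs."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

# Print the assumed learning rate at a few points in the 50-epoch run.
for epoch in (0, 10, 25, 40, 50):
    print(f"epoch {epoch:2d}: lr = {cosine_annealing_lr(epoch):.2e}")
```

Under these assumptions the learning rate starts at 1e-4, passes through half that value at epoch 25, and reaches zero at epoch 50.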
64
+ ## Performance
65
+
66
+ | Metric | Value |
67
+ |--------|-------|
68
+ | Best Loss | **0.2570** (epoch 44) |
69
+ | Initial Loss | 4.3295 |
70
+ | Loss Reduction | 93.8% (epoch 1 to epoch 50) |
71
+ | Convergence | Epoch 35-40 |
72
+
73
+ ### Training Progress
74
+
75
+ ```
76
+ Epoch 1: Loss = 4.3295
77
+ Epoch 10: Loss = 3.3269
78
+ Epoch 20: Loss = 0.7544
79
+ Epoch 30: Loss = 0.3712
80
+ Epoch 44: Loss = 0.2570 (Best)
81
+ Epoch 50: Loss = 0.2683
82
+ ```
83
+
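The reduction figures can be recomputed from the logged values. The short sketch below uses only the numbers in the training log above. It suggests that the 93.8% figure corresponds to the final (epoch 50) loss, while the reduction measured to the best checkpoint is about 94.1%.

```python
# Loss values copied from the training log above.
losses = {1: 4.3295, 10: 3.3269, 20: 0.7544, 30: 0.3712, 44: 0.2570, 50: 0.2683}

def reduction(initial, final):
    """Relative loss reduction as a fraction of the initial loss."""
    return (initial - final) / initial

best_epoch = min(losses, key=losses.get)  # epoch with the lowest logged loss
print(f"best epoch: {best_epoch}")
print(f"reduction to best epoch:  {reduction(losses[1], losses[best_epoch]):.1%}")
print(f"reduction to final epoch: {reduction(losses[1], losses[50]):.1%}")
```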
84
+ ## Model Files
85
+
86
+ This repository contains:
87
+
88
+ - `best_model.pth` - Best performing checkpoint (epoch 44, loss: 0.2570) - **598 MB**
89
+ - Additional epoch checkpoints (epochs 1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50)
90
+
91
+ **Note:** These are raw PyTorch state dictionaries, not Hugging Face Transformers models.
92
+
93
+ ## Usage
94
+
95
+ ### Installation
96
+
97
+ ```bash
98
+ pip install torch torchvision pandas numpy pillow huggingface_hub
99
+ ```
100
+
101
+ ### Model Architecture Code
102
+
103
+ You need to implement the model architecture to load these weights. Here's the required structure:
104
+
105
+ ```python
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # VisionTransformer and TextTransformer are not defined in this snippet; they
+ # come from vision_transformer.py and text_transformer.py in the full implementation.
+
+ class CLIP(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # Vision Transformer
+         self.visual = VisionTransformer(
+             img_size=224,
+             patch_size=16,
+             embed_dim=768,
+             depth=12,
+             num_heads=12,
+             output_dim=512
+         )
+         # Text Transformer
+         self.text = TextTransformer(
+             vocab_size=49408,
+             embed_dim=512,
+             max_len=77,
+             num_heads=8,
+             depth=8,
+             output_dim=512
+         )
+         # Learnable log-scale temperature, initialized to log(1 / 0.07)
+         self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+     def encode_image(self, image):
+         image_features = self.visual(image)
+         return F.normalize(image_features, dim=-1)
+
+     def encode_text(self, text):
+         text_features = self.text(text)
+         return F.normalize(text_features, dim=-1)
+
+     def forward(self, image, text):
+         image_features = self.encode_image(image)
+         text_features = self.encode_text(text)
+         # Cosine similarities scaled by exp(temperature)
+         logits = image_features @ text_features.T * torch.exp(self.temperature)
+         return logits, image_features, text_features
+ ```
147
+
148
+ ### Loading the Model
149
+
150
+ ```python
151
+ from huggingface_hub import hf_hub_download
152
+ import torch
153
+
154
+ # Download the best model checkpoint
155
+ model_path = hf_hub_download(
156
+     repo_id="siddharth-magesh/clip-flickr30k",
+     filename="best_model.pth"
158
+ )
159
+
160
+ # Initialize your model (requires architecture implementation)
161
+ model = CLIP()
162
+
163
+ # Load weights
164
+ checkpoint = torch.load(model_path, map_location='cpu')
165
+ model.load_state_dict(checkpoint)
166
+ model.eval()
167
+
168
+ print("Model loaded successfully!")
169
+ ```
170
+
171
+ ### Inference Example
172
+
173
+ ```python
+ import numpy as np
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+
+ # Image preprocessing
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+
+ # Load and preprocess image
+ image = Image.open('your_image.jpg').convert('RGB')
+ image_tensor = transform(image).unsqueeze(0)
+
+ # Simple tokenizer (hash-based)
+ # Note: Python's built-in hash() for strings is randomized per process
+ # unless PYTHONHASHSEED is fixed, so token IDs can differ between runs.
+ def tokenize(text, max_length=77):
+     tokens = text.lower().split()
+     idxs = [min(hash(w) % 49408, 49407) for w in tokens][:max_length]
+     arr = np.zeros(max_length, dtype=np.int64)  # zero-padded to max_length
+     arr[:len(idxs)] = idxs
+     return torch.tensor(arr, dtype=torch.long)
+
+ # Tokenize text
+ text = "a photo of a dog"
+ text_tensor = tokenize(text).unsqueeze(0)
+
+ # Inference (uses `model` from the loading step above)
+ with torch.no_grad():
+     image_features = model.encode_image(image_tensor)
+     text_features = model.encode_text(text_tensor)
+
+ # Compute similarity
+ similarity = (image_features @ text_features.T).item()
+ print(f"Similarity: {similarity:.4f}")
+ ```
211
+
212
+ ### Zero-Shot Image Classification
213
+
214
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # Uses `transform` and `tokenize` from the inference example above.
+ def zero_shot_classification(image, texts, model):
+     """
+     Classify an image using text descriptions.
+
+     Args:
+         image: PIL Image
+         texts: List of text descriptions
+         model: CLIP model
+
+     Returns:
+         Probabilities for each text
+     """
+     # Preprocess image
+     image_tensor = transform(image).unsqueeze(0)
+
+     # Tokenize all texts
+     text_tensors = torch.stack([tokenize(text) for text in texts])
+
+     with torch.no_grad():
+         image_features = model.encode_image(image_tensor)
+         text_features = model.encode_text(text_tensors)
+
+         # Compute similarities
+         similarities = image_features @ text_features.T
+         probs = F.softmax(similarities / 0.07, dim=-1)
+
+     return probs[0].numpy()
+
+ # Example usage
+ texts = [
+     "a photo of a dog",
+     "a photo of a cat",
+     "a photo of a bird"
+ ]
+ probs = zero_shot_classification(image, texts, model)
+
+ for text, prob in zip(texts, probs):
+     print(f"{text}: {prob:.2%}")
+ ```
254
+
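To see why the similarities are divided by the temperature before the softmax, the sketch below compares probabilities with and without that scaling. It is dependency-free and purely illustrative: the similarity values are hypothetical, not outputs of this model.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical cosine similarities between one image and three prompts.
similarities = [0.30, 0.25, 0.10]
temperature = 0.07

plain = softmax(similarities)                               # no temperature scaling
scaled = softmax([s / temperature for s in similarities])   # divided by 0.07

print([round(p, 3) for p in plain])
print([round(p, 3) for p in scaled])
```

Cosine similarities lie in a narrow range, so without scaling the probabilities stay close to uniform. Dividing by a small temperature spreads them apart while preserving their order.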
255
+ ### Image-Text Retrieval
256
+
257
+ ```python
+ def retrieve_images(query_text, image_paths, model, top_k=5):
+     """
+     Retrieve most relevant images for a text query.
+
+     Args:
+         query_text: Text query
+         image_paths: List of image file paths
+         model: CLIP model
+         top_k: Number of results to return
+
+     Returns:
+         List of (image_path, similarity) tuples
+     """
+     # Encode query
+     query_tensor = tokenize(query_text).unsqueeze(0)
+     with torch.no_grad():
+         query_features = model.encode_text(query_tensor)
+
+     # Encode images
+     similarities = []
+     for img_path in image_paths:
+         image = Image.open(img_path).convert('RGB')
+         image_tensor = transform(image).unsqueeze(0)
+
+         with torch.no_grad():
+             image_features = model.encode_image(image_tensor)
+             sim = (query_features @ image_features.T).item()
+
+         similarities.append((img_path, sim))
+
+     # Sort by similarity
+     similarities.sort(key=lambda x: x[1], reverse=True)
+     return similarities[:top_k]
+ ```
292
+
293
+ ## Full Implementation
294
+
295
+ For the complete implementation including all architecture components, visit:
296
+ - **GitHub Repository**: [Include your GitHub link here]
297
+ - **Documentation**: Comprehensive docs available in the repository
298
+
299
+ Required files for full implementation:
300
+ - `clip.py` - Main CLIP model
301
+ - `vision_transformer.py` - Vision encoder
302
+ - `text_transformer.py` - Text encoder
303
+ - `modules/transformer.py` - Transformer blocks
304
+ - `modules/multi_head_attention.py` - Attention mechanism
305
+ - `modules/multi_layer_perceptron.py` - MLP layers
306
+ - `modules/patch_embedding.py` - Patch embedding
307
+
308
+ ## Important Notes
309
+
310
+ 1. **Not Hugging Face Transformers Compatible**: This model uses custom PyTorch code, not the Transformers library.
311
+
312
+ 2. **Architecture Required**: You must implement the model architecture (see structure above) to use these weights.
313
+
314
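+ 3. **Simple Tokenizer**: Uses hash-based, whitespace-split word tokenization (not WordPiece or BPE). Python's built-in `hash()` for strings is randomized per process unless `PYTHONHASHSEED` is set, so token IDs are not reproducible across runs by default.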
315
+
316
+ 4. **Limited Dataset**: Trained on only 1,000 image-caption pairs. For production use, retrain on the full Flickr30k dataset (158,915 pairs) or larger datasets like COCO.
317
+
318
+ 5. **GPU Recommended**: Inference is faster on GPU, but CPU works fine.
319
+
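One way to make the hash-based tokenization in note 3 reproducible across runs is to swap Python's built-in `hash()` for a digest from `hashlib`. The sketch below is an assumption-laden alternative, not the tokenizer used to train these weights: the names `stable_token_id` and `tokenize_stable` are hypothetical, and the IDs it produces will not match those seen during training, so it is only suitable when retraining.

```python
import hashlib

VOCAB_SIZE = 49408
MAX_LENGTH = 77

def stable_token_id(word, vocab_size=VOCAB_SIZE):
    """Map a word to a vocabulary index with a hash that is stable across runs."""
    digest = hashlib.md5(word.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % vocab_size

def tokenize_stable(text, max_length=MAX_LENGTH):
    """Lowercase, split on whitespace, hash each word, then zero-pad to max_length."""
    ids = [stable_token_id(w) for w in text.lower().split()][:max_length]
    return ids + [0] * (max_length - len(ids))

tokens = tokenize_stable("a photo of a dog")
print(tokens[:5])
```

The result is a plain list of integers; wrapping it in `torch.tensor(..., dtype=torch.long)` would give the same shape as the `tokenize` helper in the inference example.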
320
+ ## Model Configuration
321
+
322
+ ```python
+ config = {
+     # Vision Transformer
+     'img_size': 224,
+     'patch_size': 16,
+     'vision_embed_dim': 768,
+     'vision_depth': 12,
+     'vision_heads': 12,
+     'vision_dropout': 0.1,
+
+     # Text Transformer
+     'vocab_size': 49408,
+     'text_embed_dim': 512,
+     'max_len': 77,
+     'text_heads': 8,
+     'text_depth': 8,
+     'text_dropout': 0.1,
+
+     # Common
+     'output_dim': 512,
+     'temperature': 0.07,
+ }
+ ```
345
+
346
+ ## Training Details
347
+
348
+ ### Loss Function
349
+
350
+ Symmetric contrastive loss:
351
+ ```python
+ # labels = [0, 1, ..., N-1]: the i-th image in the batch is paired with the i-th caption
+ loss = (cross_entropy(image_to_text_logits, labels) +
+         cross_entropy(text_to_image_logits, labels)) / 2
+ ```
355
+
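The symmetric loss above can be worked through on a tiny batch. The following is a dependency-free sketch in plain Python rather than the PyTorch training code; the helper names and the similarity values are hypothetical, and the fixed temperature of 0.07 stands in for the learnable parameter.

```python
import math

def cross_entropy(logits, labels):
    """Mean cross-entropy over the rows of a 2-D list of logits."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
    return total / len(logits)

def symmetric_contrastive_loss(similarity, temperature=0.07):
    """similarity[i][j] is the cosine similarity of image i and caption j."""
    n = len(similarity)
    image_to_text = [[s / temperature for s in row] for row in similarity]
    text_to_image = [list(col) for col in zip(*image_to_text)]  # transpose
    labels = list(range(n))  # the i-th image matches the i-th caption
    return (cross_entropy(image_to_text, labels) +
            cross_entropy(text_to_image, labels)) / 2

# Hypothetical similarities for a batch of three image-caption pairs.
aligned = [[0.9, 0.1, 0.0], [0.2, 0.8, 0.1], [0.0, 0.1, 0.7]]
uniform = [[0.5, 0.5, 0.5]] * 3

print(f"aligned batch: {symmetric_contrastive_loss(aligned):.4f}")
print(f"uniform batch: {symmetric_contrastive_loss(uniform):.4f}")
```

When every pair is equally similar, the loss equals ln N for a batch of N pairs. For the batch size of 64 listed earlier that is roughly 4.16, in the same range as the reported initial loss of 4.3295.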
356
+ ### Data Augmentation
357
+
358
+ Standard ImageNet normalization:
359
+ - Mean: [0.485, 0.456, 0.406]
360
+ - Std: [0.229, 0.224, 0.225]
361
+
362
+ ### Hardware
363
+
364
+ - GPU: CUDA-enabled GPU
365
+ - Training time: ~580 seconds per epoch
366
+ - Total training: 8.12 hours (50 epochs)
367
+
368
+ ## Citation
369
+
370
+ If you use this model, please cite:
371
+
372
+ ```bibtex
373
+ @misc{clip-flickr30k-2025,
374
+ author = {Siddharth Magesh},
375
+ title = {CLIP-Flickr30k: PyTorch Implementation},
376
+ year = {2025},
377
+ publisher = {Hugging Face Hub},
378
+ url = {https://huggingface.co/siddharth-magesh/clip-flickr30k}
379
+ }
380
+ ```
381
+
382
+ Original CLIP paper:
383
+ ```bibtex
384
+ @inproceedings{radford2021learning,
385
+ title={Learning Transferable Visual Models From Natural Language Supervision},
386
+ author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and others},
387
+ booktitle={International Conference on Machine Learning},
388
+ year={2021}
389
+ }
390
+ ```
391
+
392
+ ## License
393
+
394
+ MIT License - See LICENSE file for details.
395
+
396
+ ## Links
397
+
398
+ - **Model Card**: [Hugging Face Model Hub](https://huggingface.co/siddharth-magesh/clip-flickr30k)
399
+ - **Dataset**: [Flickr30k on Kaggle](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset)
400
+ - **Original CLIP Paper**: [arXiv:2103.00020](https://arxiv.org/abs/2103.00020)
401
+
402
+ ## Contact
403
+
404
+ For questions or issues:
405
+ - Create an issue on GitHub
406
+ - Discussion tab on Hugging Face
407
+
408
+ ## Acknowledgments
409
+
410
+ - OpenAI for the original CLIP architecture
411
+ - Flickr30k dataset creators
412
+ - PyTorch team
413
+
414
+ ---
415
+
416
+ **Note**: This is an educational implementation. For production use, consider:
417
+ 1. Training on larger datasets (COCO, Conceptual Captions, LAION)
418
+ 2. Using proper tokenizers (BPE, WordPiece)
419
+ 3. Pre-training on web-scale data
420
+ 4. Fine-tuning for specific tasks