cooper0914 committed on
Commit
31385aa
·
verified ·
1 Parent(s): 16a4dc8

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +351 -0
  2. photo_to_sketch.pth +3 -0
  3. sketch_to_photo.pth +3 -0
app.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import streamlit as st
2
+ # from PIL import Image
3
+ # import torch
4
+ # from torchvision import transforms
5
+ # from io import BytesIO
6
+
7
+ # # --------------------------
8
+ # # ⚙️ Streamlit Page Config
9
+ # # --------------------------
10
+ # st.set_page_config(page_title="CycleGAN Image Translator 🎨", layout="wide", page_icon="🎭")
11
+
12
+ # st.markdown("""
13
+ # <style>
14
+ # body {
15
+ # background-color: #0E1117;
16
+ # color: white;
17
+ # }
18
+ # .stButton>button {
19
+ # background-color: #059bdd;
20
+ # color: white;
21
+ # border-radius: 10px;
22
+ # padding: 0.5em 1em;
23
+ # font-size: 1.1em;
24
+ # }
25
+ # </style>
26
+ # """, unsafe_allow_html=True)
27
+
28
+ # st.title("🎨 CycleGAN Image Translator")
29
+ # st.markdown("Convert between **Sketch ↔ Real Image** using your trained model.")
30
+
31
+ # # --------------------------
32
+ # # 🧠 Load Model
33
+ # # --------------------------
34
+ # @st.cache_resource
35
+ # def load_model():
36
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ # model = torch.load("cyclegan_model.pth", map_location=device)
38
+ # model.eval()
39
+ # return model, device
40
+
41
+ # model, device = load_model()
42
+
43
+ # # --------------------------
44
+ # # 🖼 Image Processing Utils
45
+ # # --------------------------
46
+ # transform = transforms.Compose([
47
+ # transforms.Resize((256, 256)),
48
+ # transforms.ToTensor(),
49
+ # transforms.Normalize((0.5,), (0.5,))
50
+ # ])
51
+
52
+ # def tensor_to_image(tensor):
53
+ # tensor = tensor.squeeze(0).detach().cpu()
54
+ # tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
55
+ # return transforms.ToPILImage()(tensor)
56
+
57
+ # # --------------------------
58
+ # # 🚀 UI Workflow
59
+ # # --------------------------
60
+ # uploaded_file = st.file_uploader("Upload an image (JPG or PNG)", type=["jpg", "jpeg", "png"])
61
+
62
+ # if uploaded_file:
63
+ # input_image = Image.open(uploaded_file).convert("RGB")
64
+ # st.image(input_image, caption="Uploaded Image", use_container_width=True)
65
+
66
+ # if st.button("✨ Generate"):
67
+ # with st.spinner("Running the model... please wait ⏳"):
68
+ # input_tensor = transform(input_image).unsqueeze(0).to(device)
69
+ # with torch.no_grad():
70
+ # output = model(input_tensor)
71
+ # output_image = tensor_to_image(output)
72
+
73
+ # st.image(output_image, caption="Generated Output", use_container_width=True)
74
+
75
+ # # Option to download
76
+ # buf = BytesIO()
77
+ # output_image.save(buf, format="JPEG")
78
+ # byte_im = buf.getvalue()
79
+ # st.download_button("📥 Download Result", data=byte_im, file_name="output.jpg", mime="image/jpeg")
80
+
81
+ import streamlit as st
82
+ import torch
83
+ import torch.nn as nn
84
+ from PIL import Image
85
+ import numpy as np
86
+ from torchvision import transforms
87
+ import io
88
+
89
# Page-level settings: browser-tab title and icon, full-width layout.
st.set_page_config(page_title="Face ↔ Sketch CycleGAN", page_icon="🎨", layout="wide")
95
+
96
+ # Generator Architecture (same as training)
97
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv stages wrapped in a skip connection.

    Each stage is ``ReflectionPad2d(1) -> Conv2d(3x3) -> InstanceNorm2d``;
    only the first stage is followed by a ReLU. The channel count is the
    same on the way in and out, which is what allows ``x + block(x)``.

    Args:
        in_channels: Number of channels of the input (and output) tensor.
    """

    def __init__(self, in_channels):
        super().__init__()

        stages = []
        for followed_by_relu in (True, False):
            stages.append(nn.ReflectionPad2d(1))
            stages.append(nn.Conv2d(in_channels, in_channels, kernel_size=3))
            stages.append(nn.InstanceNorm2d(in_channels))
            if followed_by_relu:
                stages.append(nn.ReLU(inplace=True))

        # Attribute name and layer positions determine the state_dict keys
        # ("block.1.*", "block.5.*"), so both are kept stable.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return the input plus the block's learned residual."""
        residual = self.block(x)
        return x + residual
112
+
113
+
114
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    ``self.model`` is an ``nn.Sequential`` laid out as:
      1. Reflection-padded 7x7 convolution to 64 channels.
      2. Two stride-2 3x3 convolutions that double the channels
         (64 -> 128 -> 256).
      3. ``num_residual_blocks`` residual blocks at 256 channels.
      4. Two stride-2 3x3 transposed convolutions that halve the channels
         (256 -> 128 -> 64).
      5. Reflection-padded 7x7 convolution to ``output_channels`` + ``Tanh``.

    Every convolution except the last is followed by ``InstanceNorm2d`` and
    ``ReLU``. The position of each layer inside the ``Sequential`` fixes the
    ``state_dict`` key names, so the order must stay stable for previously
    saved weights to keep loading.

    Args:
        input_channels: Channels of the input image tensor.
        output_channels: Channels of the generated image tensor.
        num_residual_blocks: How many residual blocks sit in the bottleneck.
    """

    def __init__(self, input_channels=3, output_channels=3, num_residual_blocks=9):
        super().__init__()

        stem_width = 64

        # Stem: wide receptive field at full resolution.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, stem_width, kernel_size=7),
            nn.InstanceNorm2d(stem_width),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each step uses stride 2 and doubles the channel count.
        width = stem_width
        for _ in range(2):
            wider = width * 2
            layers.extend([
                nn.Conv2d(width, wider, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            width = wider

        # Bottleneck: residual blocks at the encoder's final width.
        layers.extend(ResidualBlock(width) for _ in range(num_residual_blocks))

        # Decoder: each step uses stride 2 (transposed) and halves the channels.
        for _ in range(2):
            narrower = width // 2
            layers.extend([
                nn.ConvTranspose2d(
                    width, narrower, kernel_size=3, stride=2,
                    padding=1, output_padding=1,
                ),
                nn.InstanceNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            width = narrower

        # Head: project back to image channels; Tanh bounds outputs to [-1, 1].
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(stem_width, output_channels, kernel_size=7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full encoder / residual / decoder stack."""
        return self.model(x)
160
+
161
+
162
+ # Cache models to avoid reloading
163
@st.cache_resource
def load_models():
    """Build both generators and restore their weights from disk.

    Returns:
        A ``(G_AB, G_BA, device)`` tuple where ``G_AB`` is the generator
        restored from ``photo_to_sketch.pth`` (Photo → Sketch), ``G_BA`` is
        the generator restored from ``sketch_to_photo.pth`` (Sketch → Photo),
        and ``device`` is the torch device both were moved to.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when a checkpoint
        file is missing, unreadable, lacks the ``'model_state_dict'`` entry,
        or does not match the ``Generator`` layout.
    """
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    restored = []
    # Tuple order decides which generator is returned as G_AB and which as G_BA.
    for weights_path in ('photo_to_sketch.pth', 'sketch_to_photo.pth'):
        generator = Generator().to(device)
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(weights_path, map_location=device)
        generator.load_state_dict(checkpoint['model_state_dict'])
        generator.eval()  # evaluation mode: the app only runs inference
        restored.append(generator)

    G_AB, G_BA = restored
    return G_AB, G_BA, device
181
+
182
+
183
def preprocess_image(image, target_size=256):
    """Turn a PIL image into a normalized, batched input tensor.

    Args:
        image: Source image; it is converted to RGB before any other step.
        target_size: Edge length (pixels) of the square the image is resized to.

    Returns:
        The transformed tensor with a leading batch dimension of size 1.
        Channels are normalized with mean 0.5 and std 0.5.
    """
    steps = [
        transforms.Resize((target_size, target_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    pipeline = transforms.Compose(steps)

    rgb_image = image.convert('RGB')
    single = pipeline(rgb_image)
    # Add a batch dimension in front of the transformed tensor.
    return single.unsqueeze(0)
193
+
194
+
195
def postprocess_image(tensor):
    """Convert a generator output tensor into a PIL image.

    Args:
        tensor: Model output. After all size-1 dimensions are squeezed away
            it must be 3-D with channels first, otherwise the transpose fails.
            (Assumed to be a single-image batch in [-1, 1]; confirm against
            the caller.)

    Returns:
        A ``PIL.Image`` built from the 8-bit, channels-last pixel array.
    """
    chw = tensor.detach().cpu().squeeze().numpy()
    hwc = np.transpose(chw, (1, 2, 0))

    # Undo the (x - 0.5) / 0.5 normalization, then clamp to the valid range.
    unit_range = np.clip(hwc * 0.5 + 0.5, 0, 1)
    pixels = (unit_range * 255).astype(np.uint8)
    return Image.fromarray(pixels)
202
+
203
+
204
def detect_image_type(image):
    """Guess whether *image* is a sketch or a photo.

    The image is reduced to grayscale and two statistics are taken over all
    pixels: the standard deviation and the mean. The image is called a
    sketch only when the standard deviation exceeds 80 AND the mean is
    either above 180 or below 100. Everything else is called a photo.

    Args:
        image: Object exposing ``convert('L')`` (e.g. a PIL image).

    Returns:
        ``"sketch"`` or ``"photo"``.
    """
    gray = np.array(image.convert('L'))

    spread = np.std(gray)
    brightness = np.mean(gray)

    strong_contrast = spread > 80
    near_extreme = brightness > 180 or brightness < 100

    if strong_contrast and near_extreme:
        return "sketch"
    return "photo"
220
+
221
+
222
def convert_image(image, model, device):
    """Translate *image* with *model* and return the result as a PIL image.

    Args:
        image: Input image, passed to ``preprocess_image``.
        model: Callable (generator) applied to the preprocessed batch.
        device: Torch device the input batch is moved to before inference.

    Returns:
        The output of ``postprocess_image`` applied to the model's result.
    """
    batch = preprocess_image(image).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        generated = model(batch)

    return postprocess_image(generated)
230
+
231
+
232
+ # Main App
233
def main():
    """Render the Streamlit page and run the requested conversion.

    Flow:
      1. Load both generators via ``load_models``; on failure show the error
         and stop the script run.
      2. Let the user pick a conversion mode in the sidebar.
      3. Left column: obtain an image from the file uploader or the camera.
      4. Right column: pick the generator (explicitly, or through
         ``detect_image_type`` in "Auto-detect" mode), run it, show the
         result and offer it as a PNG download.
      5. Render the "About" expander and the footer.
    """
    st.title("🎨 Face ↔ Sketch CycleGAN")
    st.markdown("Convert photos to sketches and sketches to photos using CycleGAN")

    # Load models. Only the load itself sits in the try body so that a
    # failure elsewhere is never misreported as a model-loading error.
    try:
        G_AB, G_BA, device = load_models()
    except Exception as e:  # top-level UI boundary: report, then halt this run
        st.error(f"❌ Error loading models: {e}")
        st.stop()
    else:
        st.success(f"✅ Models loaded successfully! Using: {device}")

    # Sidebar
    st.sidebar.header("⚙️ Settings")
    conversion_mode = st.sidebar.radio(
        "Conversion Mode",
        ["Auto-detect", "Photo → Sketch", "Sketch → Photo"],
        help="Auto-detect will automatically determine the input type",
    )

    # Main content
    col1, col2 = st.columns(2)

    # Defined up front so every input path leaves it bound; the output
    # column can then test it directly instead of probing locals().
    input_image = None

    with col1:
        st.header("📤 Input")
        upload_method = st.radio("Upload method:", ["Upload Image", "Use Camera"])

        if upload_method == "Upload Image":
            source_file = st.file_uploader(
                "Choose an image...",
                type=['png', 'jpg', 'jpeg'],
                help="Upload a photo or sketch",
            )
            caption = "Input Image"
        else:
            source_file = st.camera_input("Take a picture")
            caption = "Captured Image"

        if source_file is not None:
            input_image = Image.open(source_file)
            st.image(input_image, caption=caption, use_container_width=True)

    with col2:
        st.header("📥 Output")

        if input_image is not None:
            # Determine conversion direction
            if conversion_mode == "Auto-detect":
                detected_type = detect_image_type(input_image)
                st.info(f"🔍 Detected: {detected_type.upper()}")
                photo_to_sketch = detected_type == "photo"
            else:
                photo_to_sketch = conversion_mode == "Photo → Sketch"

            if photo_to_sketch:
                generator, conversion_text = G_AB, "Photo → Sketch"
            else:
                generator, conversion_text = G_BA, "Sketch → Photo"

            output_image = convert_image(input_image, generator, device)
            st.image(output_image, caption=f"Output ({conversion_text})", use_container_width=True)

            # Download button: serialize the result to PNG bytes in memory.
            buf = io.BytesIO()
            output_image.save(buf, format="PNG")
            byte_im = buf.getvalue()

            st.download_button(
                label="⬇️ Download Result",
                data=byte_im,
                file_name=f"cyclegan_output_{conversion_text.replace(' → ', '_to_')}.png",
                mime="image/png",
            )
        else:
            st.info("👆 Upload or capture an image to see the conversion")

    # Information section
    with st.expander("ℹ️ About this app"):
        st.markdown("""
        ### CycleGAN Face-Sketch Converter

        This application uses CycleGAN (Cycle-Consistent Generative Adversarial Networks)
        to convert between face photos and sketches.

        **Features:**
        - 🎨 Photo to Sketch conversion
        - 🖼️ Sketch to Photo conversion
        - 🔍 Automatic input type detection
        - 📸 Camera support

        **How it works:**
        CycleGAN learns to translate images between two domains without paired examples.
        It uses cycle consistency loss to ensure the translation is meaningful.

        **Model Details:**
        - Architecture: ResNet-based Generator
        - Training: Unpaired face-sketch dataset
        - Image size: 256x256 pixels
        """)

    # Footer
    st.markdown("---")
    st.markdown(
        "<div style='text-align: center'>Made with ❤️ using Streamlit and PyTorch</div>",
        unsafe_allow_html=True,
    )
348
+
349
+
350
# Script entry point: build the page only when this file is run directly.
if __name__ == "__main__":
    main()
photo_to_sketch.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b4123d8d4a392e696f9236e56e28b6b0ad20feb39c84ae01d4288f068bba10c
3
+ size 45532419
sketch_to_photo.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af23efe5b5390c3682ad37e1a3c5fce1203fb1fcb092728fbd378a8409c6be87
3
+ size 45532419