cooper0914 committed on
Commit
5119d93
·
verified ·
1 Parent(s): 31385aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -350
app.py CHANGED
@@ -1,351 +1,271 @@
1
- # import streamlit as st
2
- # from PIL import Image
3
- # import torch
4
- # from torchvision import transforms
5
- # from io import BytesIO
6
-
7
- # # --------------------------
8
- # # ⚙️ Streamlit Page Config
9
- # # --------------------------
10
- # st.set_page_config(page_title="CycleGAN Image Translator 🎨", layout="wide", page_icon="🎭")
11
-
12
- # st.markdown("""
13
- # <style>
14
- # body {
15
- # background-color: #0E1117;
16
- # color: white;
17
- # }
18
- # .stButton>button {
19
- # background-color: #059bdd;
20
- # color: white;
21
- # border-radius: 10px;
22
- # padding: 0.5em 1em;
23
- # font-size: 1.1em;
24
- # }
25
- # </style>
26
- # """, unsafe_allow_html=True)
27
-
28
- # st.title("🎨 CycleGAN Image Translator")
29
- # st.markdown("Convert between **Sketch ↔ Real Image** using your trained model.")
30
-
31
- # # --------------------------
32
- # # 🧠 Load Model
33
- # # --------------------------
34
- # @st.cache_resource
35
- # def load_model():
36
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
- # model = torch.load("cyclegan_model.pth", map_location=device)
38
- # model.eval()
39
- # return model, device
40
-
41
- # model, device = load_model()
42
-
43
- # # --------------------------
44
- # # 🖼 Image Processing Utils
45
- # # --------------------------
46
- # transform = transforms.Compose([
47
- # transforms.Resize((256, 256)),
48
- # transforms.ToTensor(),
49
- # transforms.Normalize((0.5,), (0.5,))
50
- # ])
51
-
52
- # def tensor_to_image(tensor):
53
- # tensor = tensor.squeeze(0).detach().cpu()
54
- # tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
55
- # return transforms.ToPILImage()(tensor)
56
-
57
- # # --------------------------
58
- # # 🚀 UI Workflow
59
- # # --------------------------
60
- # uploaded_file = st.file_uploader("Upload an image (JPG or PNG)", type=["jpg", "jpeg", "png"])
61
-
62
- # if uploaded_file:
63
- # input_image = Image.open(uploaded_file).convert("RGB")
64
- # st.image(input_image, caption="Uploaded Image", use_container_width=True)
65
-
66
- # if st.button("✨ Generate"):
67
- # with st.spinner("Running the model... please wait ⏳"):
68
- # input_tensor = transform(input_image).unsqueeze(0).to(device)
69
- # with torch.no_grad():
70
- # output = model(input_tensor)
71
- # output_image = tensor_to_image(output)
72
-
73
- # st.image(output_image, caption="Generated Output", use_container_width=True)
74
-
75
- # # Option to download
76
- # buf = BytesIO()
77
- # output_image.save(buf, format="JPEG")
78
- # byte_im = buf.getvalue()
79
- # st.download_button("📥 Download Result", data=byte_im, file_name="output.jpg", mime="image/jpeg")
80
-
81
- import streamlit as st
82
- import torch
83
- import torch.nn as nn
84
- from PIL import Image
85
- import numpy as np
86
- from torchvision import transforms
87
- import io
88
-
89
- # Set page config
90
- st.set_page_config(
91
- page_title="Face ↔ Sketch CycleGAN",
92
- page_icon="🎨",
93
- layout="wide"
94
- )
95
-
96
- # Generator Architecture (same as training)
97
- class ResidualBlock(nn.Module):
98
- def __init__(self, in_channels):
99
- super(ResidualBlock, self).__init__()
100
- self.block = nn.Sequential(
101
- nn.ReflectionPad2d(1),
102
- nn.Conv2d(in_channels, in_channels, kernel_size=3),
103
- nn.InstanceNorm2d(in_channels),
104
- nn.ReLU(inplace=True),
105
- nn.ReflectionPad2d(1),
106
- nn.Conv2d(in_channels, in_channels, kernel_size=3),
107
- nn.InstanceNorm2d(in_channels)
108
- )
109
-
110
- def forward(self, x):
111
- return x + self.block(x)
112
-
113
-
114
- class Generator(nn.Module):
115
- def __init__(self, input_channels=3, output_channels=3, num_residual_blocks=9):
116
- super(Generator, self).__init__()
117
-
118
- model = [
119
- nn.ReflectionPad2d(3),
120
- nn.Conv2d(input_channels, 64, kernel_size=7),
121
- nn.InstanceNorm2d(64),
122
- nn.ReLU(inplace=True)
123
- ]
124
-
125
- in_channels = 64
126
- out_channels = in_channels * 2
127
- for _ in range(2):
128
- model += [
129
- nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
130
- nn.InstanceNorm2d(out_channels),
131
- nn.ReLU(inplace=True)
132
- ]
133
- in_channels = out_channels
134
- out_channels = in_channels * 2
135
-
136
- for _ in range(num_residual_blocks):
137
- model += [ResidualBlock(in_channels)]
138
-
139
- out_channels = in_channels // 2
140
- for _ in range(2):
141
- model += [
142
- nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2,
143
- padding=1, output_padding=1),
144
- nn.InstanceNorm2d(out_channels),
145
- nn.ReLU(inplace=True)
146
- ]
147
- in_channels = out_channels
148
- out_channels = in_channels // 2
149
-
150
- model += [
151
- nn.ReflectionPad2d(3),
152
- nn.Conv2d(64, output_channels, kernel_size=7),
153
- nn.Tanh()
154
- ]
155
-
156
- self.model = nn.Sequential(*model)
157
-
158
- def forward(self, x):
159
- return self.model(x)
160
-
161
-
162
- # Cache models to avoid reloading
163
- @st.cache_resource
164
- def load_models():
165
- """Load both generator models"""
166
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
167
-
168
- # Load Photo → Sketch model
169
- G_AB = Generator().to(device)
170
- checkpoint_ab = torch.load('photo_to_sketch.pth', map_location=device)
171
- G_AB.load_state_dict(checkpoint_ab['model_state_dict'])
172
- G_AB.eval()
173
-
174
- # Load Sketch → Photo model
175
- G_BA = Generator().to(device)
176
- checkpoint_ba = torch.load('sketch_to_photo.pth', map_location=device)
177
- G_BA.load_state_dict(checkpoint_ba['model_state_dict'])
178
- G_BA.eval()
179
-
180
- return G_AB, G_BA, device
181
-
182
-
183
- def preprocess_image(image, target_size=256):
184
- """Preprocess image for model input"""
185
- transform = transforms.Compose([
186
- transforms.Resize((target_size, target_size)),
187
- transforms.ToTensor(),
188
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
189
- ])
190
-
191
- image = image.convert('RGB')
192
- return transform(image).unsqueeze(0)
193
-
194
-
195
- def postprocess_image(tensor):
196
- """Convert model output back to PIL Image"""
197
- image = tensor.cpu().squeeze().detach().numpy()
198
- image = image.transpose(1, 2, 0)
199
- image = (image * 0.5 + 0.5).clip(0, 1) # Denormalize
200
- image = (image * 255).astype(np.uint8)
201
- return Image.fromarray(image)
202
-
203
-
204
- def detect_image_type(image):
205
- """
206
- Simple heuristic to detect if image is a sketch or photo
207
- Sketches typically have higher contrast and less color variation
208
- """
209
- img_array = np.array(image.convert('L'))
210
-
211
- # Calculate statistics
212
- std_dev = np.std(img_array)
213
- mean_val = np.mean(img_array)
214
-
215
- # Sketches tend to have higher std deviation and be closer to extremes
216
- if std_dev > 80 and (mean_val > 180 or mean_val < 100):
217
- return "sketch"
218
- else:
219
- return "photo"
220
-
221
-
222
- def convert_image(image, model, device):
223
- """Convert image using the specified model"""
224
- input_tensor = preprocess_image(image).to(device)
225
-
226
- with torch.no_grad():
227
- output_tensor = model(input_tensor)
228
-
229
- return postprocess_image(output_tensor)
230
-
231
-
232
- # Main App
233
- def main():
234
- st.title("🎨 Face ↔ Sketch CycleGAN")
235
- st.markdown("Convert photos to sketches and sketches to photos using CycleGAN")
236
-
237
- # Load models
238
- try:
239
- G_AB, G_BA, device = load_models()
240
- st.success(f"✅ Models loaded successfully! Using: {device}")
241
- except Exception as e:
242
- st.error(f"❌ Error loading models: {str(e)}")
243
- st.stop()
244
-
245
- # Sidebar
246
- st.sidebar.header("⚙️ Settings")
247
- conversion_mode = st.sidebar.radio(
248
- "Conversion Mode",
249
- ["Auto-detect", "Photo → Sketch", "Sketch → Photo"],
250
- help="Auto-detect will automatically determine the input type"
251
- )
252
-
253
- # Main content
254
- col1, col2 = st.columns(2)
255
-
256
- with col1:
257
- st.header("📤 Input")
258
- upload_method = st.radio("Upload method:", ["Upload Image", "Use Camera"])
259
-
260
- if upload_method == "Upload Image":
261
- uploaded_file = st.file_uploader(
262
- "Choose an image...",
263
- type=['png', 'jpg', 'jpeg'],
264
- help="Upload a photo or sketch"
265
- )
266
-
267
- if uploaded_file is not None:
268
- input_image = Image.open(uploaded_file)
269
- st.image(input_image, caption="Input Image", use_container_width=True)
270
- else:
271
- camera_photo = st.camera_input("Take a picture")
272
- if camera_photo is not None:
273
- input_image = Image.open(camera_photo)
274
- st.image(input_image, caption="Captured Image", use_container_width=True)
275
- else:
276
- input_image = None
277
-
278
- with col2:
279
- st.header("📥 Output")
280
-
281
- if 'input_image' in locals() and input_image is not None:
282
- # Determine conversion direction
283
- if conversion_mode == "Auto-detect":
284
- detected_type = detect_image_type(input_image)
285
- st.info(f"🔍 Detected: {detected_type.upper()}")
286
-
287
- if detected_type == "photo":
288
- output_image = convert_image(input_image, G_AB, device)
289
- conversion_text = "Photo → Sketch"
290
- else:
291
- output_image = convert_image(input_image, G_BA, device)
292
- conversion_text = "Sketch → Photo"
293
-
294
- elif conversion_mode == "Photo → Sketch":
295
- output_image = convert_image(input_image, G_AB, device)
296
- conversion_text = "Photo → Sketch"
297
-
298
- else: # Sketch → Photo
299
- output_image = convert_image(input_image, G_BA, device)
300
- conversion_text = "Sketch → Photo"
301
-
302
- st.image(output_image, caption=f"Output ({conversion_text})", use_container_width=True)
303
-
304
- # Download button
305
- buf = io.BytesIO()
306
- output_image.save(buf, format="PNG")
307
- byte_im = buf.getvalue()
308
-
309
- st.download_button(
310
- label="⬇️ Download Result",
311
- data=byte_im,
312
- file_name=f"cyclegan_output_{conversion_text.replace(' → ', '_to_')}.png",
313
- mime="image/png"
314
- )
315
- else:
316
- st.info("👆 Upload or capture an image to see the conversion")
317
-
318
- # Information section
319
- with st.expander("ℹ️ About this app"):
320
- st.markdown("""
321
- ### CycleGAN Face-Sketch Converter
322
-
323
- This application uses CycleGAN (Cycle-Consistent Generative Adversarial Networks)
324
- to convert between face photos and sketches.
325
-
326
- **Features:**
327
- - 🎨 Photo to Sketch conversion
328
- - 🖼️ Sketch to Photo conversion
329
- - 🔍 Automatic input type detection
330
- - 📸 Camera support
331
-
332
- **How it works:**
333
- CycleGAN learns to translate images between two domains without paired examples.
334
- It uses cycle consistency loss to ensure the translation is meaningful.
335
-
336
- **Model Details:**
337
- - Architecture: ResNet-based Generator
338
- - Training: Unpaired face-sketch dataset
339
- - Image size: 256x256 pixels
340
- """)
341
-
342
- # Footer
343
- st.markdown("---")
344
- st.markdown(
345
- "<div style='text-align: center'>Made with ❤️ using Streamlit and PyTorch</div>",
346
- unsafe_allow_html=True
347
- )
348
-
349
-
350
- if __name__ == "__main__":
351
  main()
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ import numpy as np
6
+ from torchvision import transforms
7
+ import io
8
+
9
+ # Set page config
10
+ st.set_page_config(
11
+ page_title="Face ↔ Sketch CycleGAN",
12
+ page_icon="🎨",
13
+ layout="wide"
14
+ )
15
+
16
+ # Generator Architecture (same as training)
17
+ class ResidualBlock(nn.Module):
18
+ def __init__(self, in_channels):
19
+ super(ResidualBlock, self).__init__()
20
+ self.block = nn.Sequential(
21
+ nn.ReflectionPad2d(1),
22
+ nn.Conv2d(in_channels, in_channels, kernel_size=3),
23
+ nn.InstanceNorm2d(in_channels),
24
+ nn.ReLU(inplace=True),
25
+ nn.ReflectionPad2d(1),
26
+ nn.Conv2d(in_channels, in_channels, kernel_size=3),
27
+ nn.InstanceNorm2d(in_channels)
28
+ )
29
+
30
+ def forward(self, x):
31
+ return x + self.block(x)
32
+
33
+
34
+ class Generator(nn.Module):
35
+ def __init__(self, input_channels=3, output_channels=3, num_residual_blocks=9):
36
+ super(Generator, self).__init__()
37
+
38
+ model = [
39
+ nn.ReflectionPad2d(3),
40
+ nn.Conv2d(input_channels, 64, kernel_size=7),
41
+ nn.InstanceNorm2d(64),
42
+ nn.ReLU(inplace=True)
43
+ ]
44
+
45
+ in_channels = 64
46
+ out_channels = in_channels * 2
47
+ for _ in range(2):
48
+ model += [
49
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
50
+ nn.InstanceNorm2d(out_channels),
51
+ nn.ReLU(inplace=True)
52
+ ]
53
+ in_channels = out_channels
54
+ out_channels = in_channels * 2
55
+
56
+ for _ in range(num_residual_blocks):
57
+ model += [ResidualBlock(in_channels)]
58
+
59
+ out_channels = in_channels // 2
60
+ for _ in range(2):
61
+ model += [
62
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2,
63
+ padding=1, output_padding=1),
64
+ nn.InstanceNorm2d(out_channels),
65
+ nn.ReLU(inplace=True)
66
+ ]
67
+ in_channels = out_channels
68
+ out_channels = in_channels // 2
69
+
70
+ model += [
71
+ nn.ReflectionPad2d(3),
72
+ nn.Conv2d(64, output_channels, kernel_size=7),
73
+ nn.Tanh()
74
+ ]
75
+
76
+ self.model = nn.Sequential(*model)
77
+
78
+ def forward(self, x):
79
+ return self.model(x)
80
+
81
+
82
+ # Cache models to avoid reloading
83
+ @st.cache_resource
84
+ def load_models():
85
+ """Load both generator models"""
86
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
87
+
88
+ # Load Photo → Sketch model
89
+ G_AB = Generator().to(device)
90
+ checkpoint_ab = torch.load('photo_to_sketch.pth', map_location=device)
91
+ G_AB.load_state_dict(checkpoint_ab['model_state_dict'])
92
+ G_AB.eval()
93
+
94
+ # Load Sketch → Photo model
95
+ G_BA = Generator().to(device)
96
+ checkpoint_ba = torch.load('sketch_to_photo.pth', map_location=device)
97
+ G_BA.load_state_dict(checkpoint_ba['model_state_dict'])
98
+ G_BA.eval()
99
+
100
+ return G_AB, G_BA, device
101
+
102
+
103
+ def preprocess_image(image, target_size=256):
104
+ """Preprocess image for model input"""
105
+ transform = transforms.Compose([
106
+ transforms.Resize((target_size, target_size)),
107
+ transforms.ToTensor(),
108
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
109
+ ])
110
+
111
+ image = image.convert('RGB')
112
+ return transform(image).unsqueeze(0)
113
+
114
+
115
+ def postprocess_image(tensor):
116
+ """Convert model output back to PIL Image"""
117
+ image = tensor.cpu().squeeze().detach().numpy()
118
+ image = image.transpose(1, 2, 0)
119
+ image = (image * 0.5 + 0.5).clip(0, 1) # Denormalize
120
+ image = (image * 255).astype(np.uint8)
121
+ return Image.fromarray(image)
122
+
123
+
124
+ def detect_image_type(image):
125
+ """
126
+ Simple heuristic to detect if image is a sketch or photo
127
+ Sketches typically have higher contrast and less color variation
128
+ """
129
+ img_array = np.array(image.convert('L'))
130
+
131
+ # Calculate statistics
132
+ std_dev = np.std(img_array)
133
+ mean_val = np.mean(img_array)
134
+
135
+ # Sketches tend to have higher std deviation and be closer to extremes
136
+ if std_dev > 80 and (mean_val > 180 or mean_val < 100):
137
+ return "sketch"
138
+ else:
139
+ return "photo"
140
+
141
+
142
+ def convert_image(image, model, device):
143
+ """Convert image using the specified model"""
144
+ input_tensor = preprocess_image(image).to(device)
145
+
146
+ with torch.no_grad():
147
+ output_tensor = model(input_tensor)
148
+
149
+ return postprocess_image(output_tensor)
150
+
151
+
152
+ # Main App
153
+ def main():
154
+ st.title("🎨 Face ↔ Sketch CycleGAN")
155
+ st.markdown("Convert photos to sketches and sketches to photos using CycleGAN")
156
+
157
+ # Load models
158
+ try:
159
+ G_AB, G_BA, device = load_models()
160
+ st.success(f"✅ Models loaded successfully! Using: {device}")
161
+ except Exception as e:
162
+ st.error(f"❌ Error loading models: {str(e)}")
163
+ st.stop()
164
+
165
+ # Sidebar
166
+ st.sidebar.header("⚙️ Settings")
167
+ conversion_mode = st.sidebar.radio(
168
+ "Conversion Mode",
169
+ ["Auto-detect", "Photo → Sketch", "Sketch → Photo"],
170
+ help="Auto-detect will automatically determine the input type"
171
+ )
172
+
173
+ # Main content
174
+ col1, col2 = st.columns(2)
175
+
176
+ with col1:
177
+ st.header("📤 Input")
178
+ upload_method = st.radio("Upload method:", ["Upload Image", "Use Camera"])
179
+
180
+ if upload_method == "Upload Image":
181
+ uploaded_file = st.file_uploader(
182
+ "Choose an image...",
183
+ type=['png', 'jpg', 'jpeg'],
184
+ help="Upload a photo or sketch"
185
+ )
186
+
187
+ if uploaded_file is not None:
188
+ input_image = Image.open(uploaded_file)
189
+ st.image(input_image, caption="Input Image", use_column_width=True)
190
+ else:
191
+ camera_photo = st.camera_input("Take a picture")
192
+ if camera_photo is not None:
193
+ input_image = Image.open(camera_photo)
194
+ st.image(input_image, caption="Captured Image", use_column_width=True)
195
+ else:
196
+ input_image = None
197
+
198
+ with col2:
199
+ st.header("📥 Output")
200
+
201
+ if 'input_image' in locals() and input_image is not None:
202
+ # Determine conversion direction
203
+ if conversion_mode == "Auto-detect":
204
+ detected_type = detect_image_type(input_image)
205
+ st.info(f"🔍 Detected: {detected_type.upper()}")
206
+
207
+ if detected_type == "photo":
208
+ output_image = convert_image(input_image, G_AB, device)
209
+ conversion_text = "Photo → Sketch"
210
+ else:
211
+ output_image = convert_image(input_image, G_BA, device)
212
+ conversion_text = "Sketch → Photo"
213
+
214
+ elif conversion_mode == "Photo → Sketch":
215
+ output_image = convert_image(input_image, G_AB, device)
216
+ conversion_text = "Photo → Sketch"
217
+
218
+ else: # Sketch → Photo
219
+ output_image = convert_image(input_image, G_BA, device)
220
+ conversion_text = "Sketch → Photo"
221
+
222
+ st.image(output_image, caption=f"Output ({conversion_text})", use_column_width=True)
223
+
224
+ # Download button
225
+ buf = io.BytesIO()
226
+ output_image.save(buf, format="PNG")
227
+ byte_im = buf.getvalue()
228
+
229
+ st.download_button(
230
+ label="⬇️ Download Result",
231
+ data=byte_im,
232
+ file_name=f"cyclegan_output_{conversion_text.replace(' → ', '_to_')}.png",
233
+ mime="image/png"
234
+ )
235
+ else:
236
+ st.info("👆 Upload or capture an image to see the conversion")
237
+
238
+ # Information section
239
+ with st.expander("ℹ️ About this app"):
240
+ st.markdown("""
241
+ ### CycleGAN Face-Sketch Converter
242
+
243
+ This application uses CycleGAN (Cycle-Consistent Generative Adversarial Networks)
244
+ to convert between face photos and sketches.
245
+
246
+ **Features:**
247
+ - 🎨 Photo to Sketch conversion
248
+ - 🖼️ Sketch to Photo conversion
249
+ - 🔍 Automatic input type detection
250
+ - 📸 Camera support
251
+
252
+ **How it works:**
253
+ CycleGAN learns to translate images between two domains without paired examples.
254
+ It uses cycle consistency loss to ensure the translation is meaningful.
255
+
256
+ **Model Details:**
257
+ - Architecture: ResNet-based Generator
258
+ - Training: Unpaired face-sketch dataset
259
+ - Image size: 256x256 pixels
260
+ """)
261
+
262
+ # Footer
263
+ st.markdown("---")
264
+ st.markdown(
265
+ "<div style='text-align: center'>Made with ❤️ using Streamlit and PyTorch</div>",
266
+ unsafe_allow_html=True
267
+ )
268
+
269
+
270
+ if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  main()