"""Train, evaluate and run inference for an AI-generated-music detector.

Example:
    python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --oversample True

Model: audiocnn encoder - crossattn based decoder (ViT) model.

Labels follow the convention used throughout this script: 0 = REAL, 1 = FAKE.
"""
import argparse
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from sklearn.metrics import (
    balanced_accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from attentionmap import visualize_attention_map
from confusion_matrix import plot_confusion_matrix
from datalib import (
    FakeMusicCapsDataset,
    train_files,
    val_files,
    train_labels,
    val_labels,
    closed_test_files,
    closed_test_labels,
    open_test_files,
    open_test_labels,
    preprocess_audio,
)
from networks import CCV


def count_parameters(model):
    """Return the number of trainable (``requires_grad``) parameters of *model*."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter accepts the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


class FocalLoss(torch.nn.Module):
    """Focal loss built on top of per-sample cross entropy.

    Computes ``alpha * (1 - pt) ** gamma * ce`` where ``pt = exp(-ce)``.

    Args:
        alpha: constant scale applied to every sample's loss.
        gamma: focusing exponent applied to ``(1 - pt)``.
        reduction: ``'mean'``, ``'sum'``; any other value returns the
            unreduced per-sample loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


# Argument parsing
parser = argparse.ArgumentParser(description='AI Music Detection Training')
parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
parser.add_argument('--model_name', type=str, choices=['audiocnn', 'CCV'], default='CCV', help='Model name')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs')
parser.add_argument('--audio_duration', type=float, default=10, help='Length of the audio slice in seconds')
parser.add_argument('--patience_counter', type=int, default=5, help='Early stopping patience')
parser.add_argument('--log_dir', type=str, default='', help='TensorBoard log directory')
parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory')
parser.add_argument("--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)")
parser.add_argument("--loss_type", type=str, choices=["ce", "weighted_ce", "focal"], default="ce",
                    help="Loss function type")
parser.add_argument('--inference', type=str, help='Path to a .wav file for inference')
parser.add_argument("--closed_test", action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
parser.add_argument("--open_test", action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
# real data oversampling; parsed with _str2bool so that "--oversample False" really is False
parser.add_argument("--oversample", type=_str2bool, default=True, help="Apply Oversampling to balance classes")

args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

wandb.init(project="",
           name=f"{args.model_name}_lr{args.learning_rate}_ep{args.epochs}_bs{args.batch_size}",
           config=args)

if args.model_name == 'CCV':
    # No unconditional .cuda() here: the model is moved to `device` below so
    # the script still works when CUDA is unavailable.
    model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2)
    feat_type = 'mel'
else:
    raise ValueError(f"Invalid model name: {args.model_name}")

model = model.to(device)
print(f"Using model: {args.model_name}, Parameters: {count_parameters(model)}")
print(f"weight_decay WD: {args.weight_decay}")

optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

if args.loss_type == "ce":
    print("Using CrossEntropyLoss")
    criterion = nn.CrossEntropyLoss()
elif args.loss_type == "weighted_ce":
    print("Using Weighted CrossEntropyLoss")
    num_real = sum(1 for label in train_labels if label == 0)
    num_fake = sum(1 for label in train_labels if label == 1)
    if num_real == 0 or num_fake == 0:
        # Inverse-frequency weights are undefined when a class is absent.
        raise ValueError(
            "weighted_ce needs at least one REAL (0) and one FAKE (1) training label; "
            f"got {num_real} real and {num_fake} fake"
        )
    total_samples = num_real + num_fake
    weight_real = total_samples / (2 * num_real)
    weight_fake = total_samples / (2 * num_fake)
    class_weights = torch.tensor([weight_real, weight_fake]).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
elif args.loss_type == "focal":
    print("Using Focal Loss")
    criterion = FocalLoss().to(device)

# os.makedirs('') raises, so only create the directory when one was given;
# with an empty path checkpoints end up in the current working directory.
if args.ckpt_path:
    os.makedirs(args.ckpt_path, exist_ok=True)

train_dataset = FakeMusicCapsDataset(train_files, train_labels, feat_type=feat_type,
                                     target_duration=args.audio_duration)
val_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type,
                                   target_duration=args.audio_duration)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)


def _binary_metrics(labels, preds):
    """Return (balanced accuracy, precision, recall, F1) for binary predictions."""
    return (
        balanced_accuracy_score(labels, preds),
        precision_score(labels, preds, average="binary"),
        recall_score(labels, preds, average="binary"),
        f1_score(labels, preds, average="binary"),
    )


def _attention_to_numpy(attn):
    """Convert an attention-map object (list of tensors, tensor, or array-like) to ndarray."""
    if isinstance(attn, list):
        return np.array([t.detach().cpu().numpy() for t in attn])
    if isinstance(attn, torch.Tensor):
        return attn.detach().cpu().numpy()
    return np.array(attn)


def _save_attention_map(attn, epoch, out_dir):
    """Render the first entry of *attn* and save it as a PNG under *out_dir*.

    The figure is written to disk and closed instead of being shown, so the
    training loop never blocks on a GUI window and figures do not accumulate.
    """
    attn_map_numpy = _attention_to_numpy(attn)
    print(f"Attention Map Shape: {attn_map_numpy.shape}")
    if attn_map_numpy.ndim == 0 or len(attn_map_numpy) == 0:
        print(f"Warning: attention_maps[0] is empty! Shape={attn_map_numpy.shape}")
        return

    # NOTE(review): imshow needs a 2-D (or RGB/RGBA) array; this assumes the
    # first entry of the attention map has such a shape - confirm against CCV.
    fig, ax = plt.subplots(figsize=(10, 8))
    image = ax.imshow(attn_map_numpy[0], cmap='viridis', interpolation='nearest')
    ax.set_title(f"Attention Map - Epoch {epoch+1}")
    fig.colorbar(image, ax=ax)
    fig.savefig(os.path.join(out_dir or ".", f"attention_map_epoch_{epoch + 1}.png"))
    plt.close(fig)


def train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args):
    """Train *model*, validating after each epoch, with checkpointing and early stopping.

    After every epoch the validation balanced accuracy is compared with the
    best value seen so far. A strictly higher value saves the model's
    ``state_dict`` to ``<args.ckpt_path>/best_model_<args.model_name>.pth``
    and resets the patience counter; otherwise the counter is incremented and
    training stops once it reaches ``args.patience_counter``.

    Metrics are logged to wandb, the confusion matrix is sent to the
    TensorBoard writer, and (if the model exposes ``get_attention_maps``) an
    attention-map image is saved per epoch.

    Args:
        model: network returning class logits; ``argmax(dim=1)`` is the prediction.
        train_loader: iterable of ``(data, target)`` training batches.
        val_loader: iterable of ``(data, target)`` validation batches.
        optimizer: optimizer stepping *model*'s parameters.
        scheduler: LR scheduler, stepped once per completed epoch.
        criterion: loss called as ``criterion(output, target)``.
        device: device the batches are moved to.
        args: parsed CLI namespace (uses ``epochs``, ``log_dir``, ``ckpt_path``,
            ``model_name``, ``patience_counter``).
    """
    writer = SummaryWriter(log_dir=args.log_dir)
    # Balanced accuracy is higher-is-better, so start from -inf and look for increases.
    best_val_bal_acc = float('-inf')
    early_stop_cnt = 0

    for epoch in range(args.epochs):
        print(f"\n[Epoch {epoch + 1}/{args.epochs}]")

        # ---- training pass ----
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        all_train_preds, all_train_labels = [], []

        for data, target in tqdm(train_loader, desc="Train", leave=False):
            data = data.to(device)
            target = target.to(device)

            output = model(data)
            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch loss is a per-sample mean.
            train_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            train_correct += (preds == target).sum().item()
            train_total += target.size(0)

            all_train_labels.extend(target.tolist())
            all_train_preds.extend(preds.tolist())
            # Attention maps are intentionally not collected here: only the
            # validation maps are ever visualised, and keeping one per batch
            # would hold tensors alive for the whole epoch.

        train_loss /= train_total
        train_acc = train_correct / train_total
        train_bal_acc, train_precision, train_recall, train_f1 = _binary_metrics(
            all_train_labels, all_train_preds)

        wandb.log({
            "Train Loss": train_loss,
            "Train Accuracy": train_acc,
            "Train Precision": train_precision,
            "Train Recall": train_recall,
            "Train F1 Score": train_f1,
            "Train B_ACC": train_bal_acc,
        })

        print(f"Train Epoch: {epoch+1} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} | "
              f"Train B_ACC: {train_bal_acc:.4f} | Train Prec: {train_precision:.3f} | "
              f"Train Rec: {train_recall:.3f} | Train F1: {train_f1:.3f}")

        # ---- validation pass ----
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        all_val_preds, all_val_labels = [], []
        first_attention_map = None  # only the first validation batch is visualised

        with torch.no_grad():
            for data, target in tqdm(val_loader, desc=" Val ", leave=False):
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)

                val_loss += loss.item() * data.size(0)
                preds = output.argmax(dim=1)
                val_correct += (preds == target).sum().item()
                val_total += target.size(0)

                all_val_labels.extend(target.tolist())
                all_val_preds.extend(preds.tolist())

                if first_attention_map is None and hasattr(model, "get_attention_maps"):
                    first_attention_map = model.get_attention_maps()

        val_loss /= val_total
        val_acc = val_correct / val_total
        val_bal_acc, val_precision, val_recall, val_f1 = _binary_metrics(
            all_val_labels, all_val_preds)

        wandb.log({
            "Validation Loss": val_loss,
            "Validation Accuracy": val_acc,
            "Validation Precision": val_precision,
            "Validation Recall": val_recall,
            "Validation F1 Score": val_f1,
            "Validation B_ACC": val_bal_acc,
        })

        # The old message also printed a "progress" fraction computed from the
        # training loop's leftover batch index; that number was meaningless.
        print(f"Val Epoch: {epoch+1} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f} | "
              f"Val B_ACC: {val_bal_acc:.4f} | Val Prec: {val_precision:.3f} | "
              f"Val Rec: {val_recall:.3f} | Val F1: {val_f1:.3f}")

        if first_attention_map is not None:
            print(f"Visualizing Attention Map at Epoch {epoch+1}")
            _save_attention_map(first_attention_map, epoch, args.ckpt_path)

        # Logged before the early-stopping check so the final epoch is included.
        plot_confusion_matrix(all_val_labels, all_val_preds, classes=["REAL", "FAKE"],
                              writer=writer, epoch=epoch)

        if val_bal_acc > best_val_bal_acc:
            best_val_bal_acc = val_bal_acc
            early_stop_cnt = 0
            torch.save(model.state_dict(),
                       os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"))
            print("Best model saved.")
        else:
            early_stop_cnt += 1
            print(f'PATIENCE {early_stop_cnt}/{args.patience_counter}')
            if early_stop_cnt >= args.patience_counter:
                print("Early stopping triggered.")
                break

        scheduler.step()

    wandb.finish()
    writer.close()


def predict(audio_path):
    """Classify a single audio file as REAL or FAKE using the best checkpoint.

    Loads ``best_model_<model_name>.pth`` from ``args.ckpt_path`` into the
    module-level ``model``, runs ``preprocess_audio(audio_path)`` through it
    and prints the verdict with its softmax probability (class index 1 is
    treated as FAKE).
    """
    ckpt_file = os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth")
    print(f"Loading model from {ckpt_file}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(ckpt_file, map_location=device))
    model.eval()

    input_tensor = preprocess_audio(audio_path).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1)
        ai_music_prob = probabilities[0, 1].item()

    if ai_music_prob > 0.5:
        print(f"FAKE MUSIC {ai_music_prob:.2%}")
    else:
        print(f"REAL MUSIC {100 - ai_music_prob * 100:.2f}%")


def Test(model, test_loader, criterion, device):
    """Evaluate the best checkpoint on *test_loader* and print the metrics.

    Loads ``best_model_<model_name>.pth`` from ``args.ckpt_path`` into *model*,
    then reports per-sample mean loss, accuracy, balanced accuracy, precision,
    recall and F1 (binary, positive class = 1).
    """
    ckpt_file = os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(ckpt_file, map_location=device))
    model.eval()

    test_loss, test_correct, test_total = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for data, target in tqdm(test_loader, desc=" Test ", leave=False):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            test_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            test_correct += (preds == target).sum().item()
            test_total += target.size(0)

            all_labels.extend(target.tolist())
            all_preds.extend(preds.tolist())

    test_loss /= test_total
    test_acc = test_correct / test_total
    test_bal_acc, test_precision, test_recall, test_f1 = _binary_metrics(all_labels, all_preds)

    print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
          f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
          f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")


if __name__ == "__main__":
    train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args)

    # Pick the evaluation split once, then build a single dataset/loader for it.
    if args.closed_test:
        print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
        test_files, test_labels = closed_test_files, closed_test_labels
    elif args.open_test:
        print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
        test_files, test_labels = open_test_files, open_test_labels
    else:
        print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
        test_files, test_labels = val_files, val_labels

    test_dataset = FakeMusicCapsDataset(test_files, test_labels, feat_type=feat_type,
                                        target_duration=args.audio_duration)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

    print("\nEvaluating Model on Test Set...")
    Test(model, test_loader, criterion, device)

    if args.inference:
        if not os.path.exists(args.inference):
            print(f"[ERROR] No File Found: {args.inference}")
        else:
            predict(args.inference)