"""
Secure Storage for Bitcoin Wallet Keys
Handles encrypted storage of private keys
"""

import base64
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Dict

from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

class SecureKeyStorage:
    """Password-encrypted, file-backed storage for wallet private keys.

    Each wallet is stored as ``<storage_dir>/<address>.enc``: a JSON object
    with a base64-encoded random ``salt`` and an ``encrypted_key`` Fernet
    token. The Fernet key is derived from the password with
    PBKDF2-HMAC-SHA256.
    """

    # KDF parameters. Wallet files do not record these values, so changing
    # them makes every previously saved wallet undecryptable.
    _KDF_ITERATIONS = 480000
    _SALT_SIZE = 16

    _logger = logging.getLogger(__name__)

    def __init__(self, storage_dir: str = "secure_wallets"):
        """Initialize secure key storage.

        Args:
            storage_dir: Directory holding the ``.enc`` wallet files. It is
                created, including missing parents, if absent. On POSIX a
                newly created leaf directory is restricted to the owner
                (mode 0o700, further reduced by the process umask).
        """
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(mode=0o700, parents=True, exist_ok=True)

    def _wallet_path(self, address: str) -> Path:
        """Return the wallet file path for *address*, rejecting unsafe names.

        The address becomes part of a filename. Anything that would place the
        file outside ``storage_dir`` is refused: path separators, absolute
        paths, drive prefixes. NUL bytes and the empty string are refused too.

        Raises:
            ValueError: If *address* is empty, contains NUL, or would resolve
                to a file outside ``storage_dir``.
        """
        if not address or "\x00" in address:
            raise ValueError(f"Invalid wallet address: {address!r}")
        candidate = self.storage_dir / f"{address}.enc"
        # Any separator, absolute path or drive component changes the parent,
        # so this single check covers traversal on every platform.
        if candidate.parent != self.storage_dir:
            raise ValueError(f"Invalid wallet address: {address!r}")
        return candidate

    def _generate_key(self, password: str, salt: Optional[bytes] = None) -> tuple[bytes, bytes]:
        """Derive a Fernet key from *password*.

        Args:
            password: Password to stretch; encoded as UTF-8.
            salt: KDF salt. When ``None`` a fresh random salt of
                ``_SALT_SIZE`` bytes is generated (the save path).

        Returns:
            ``(key, salt)``: *key* is the urlsafe-base64-encoded 32-byte
            derived key that ``Fernet`` expects; *salt* is the salt used.
        """
        if salt is None:
            salt = os.urandom(self._SALT_SIZE)

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=self._KDF_ITERATIONS,
        )
        key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
        return key, salt

    def _write_private_json(self, path: Path, data: dict[str, str]) -> None:
        """Atomically write *data* as JSON to *path* with owner-only access.

        The data goes to a temporary file in the same directory, created by
        ``tempfile.mkstemp`` (mode 0o600 on POSIX). The temp file is fsynced
        and then renamed over *path* with ``os.replace``. If anything fails
        part-way, any existing file at *path* is left intact. This matters
        because that file may hold the only copy of a private key.
        """
        fd, tmp_name = tempfile.mkstemp(
            dir=path.parent, prefix=f".{path.name}.", suffix=".tmp"
        )
        tmp_path = Path(tmp_name)
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as handle:
                json.dump(data, handle)
                handle.flush()
                os.fsync(handle.fileno())
            os.replace(tmp_path, path)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

    def save_wallet(self, address: str, private_key: str, password: str):
        """
        Save encrypted wallet keys, replacing any existing wallet for *address*.
        Args:
            address: Bitcoin address (used as the file name stem)
            private_key: Private key in hex format
            password: Encryption password
        Raises:
            ValueError: If *address* is not a safe file name (see _wallet_path).
        """
        # Validate the destination before spending time on the KDF.
        wallet_file = self._wallet_path(address)

        key, salt = self._generate_key(password)
        encrypted_key = Fernet(key).encrypt(private_key.encode())

        # The salt is stored next to the ciphertext so the key can be rederived.
        wallet_data = {
            'salt': base64.b64encode(salt).decode(),
            'encrypted_key': encrypted_key.decode(),
        }
        self._write_private_json(wallet_file, wallet_data)

    def load_wallet(self, address: str, password: str) -> Optional[str]:
        """
        Load and decrypt wallet keys
        Args:
            address: Bitcoin address
            password: Encryption password
        Returns:
            Decrypted private key in hex format, or None when the address is
            invalid, no wallet exists for it, the file is unreadable or
            malformed, or the password is wrong. Failures other than a
            missing wallet are logged as warnings.
        """
        try:
            wallet_file = self._wallet_path(address)
        except ValueError as exc:
            self._logger.warning("Refusing to load wallet: %s", exc)
            return None

        try:
            with open(wallet_file, 'r', encoding="utf-8") as f:
                wallet_data = json.load(f)
        except FileNotFoundError:
            return None
        except (OSError, ValueError) as exc:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            self._logger.warning(
                "Error reading wallet %s (%s): %s", address, type(exc).__name__, exc
            )
            return None

        try:
            salt = base64.b64decode(wallet_data['salt'])
            key, _ = self._generate_key(password, salt)
            token = wallet_data['encrypted_key'].encode()
            return Fernet(key).decrypt(token).decode()
        except (InvalidToken, KeyError, TypeError, ValueError, AttributeError) as exc:
            # InvalidToken (empty message) means a wrong password or a
            # tampered token. The others mean malformed wallet contents.
            self._logger.warning(
                "Error decrypting wallet %s (%s): %s", address, type(exc).__name__, exc
            )
            return None

    def list_wallets(self) -> list[str]:
        """Return the addresses of all stored wallets, sorted."""
        return sorted(
            path.stem for path in self.storage_dir.glob("*.enc") if path.is_file()
        )
