"""
HuggingFace Cache Manager - Finds, manages and optimizes HuggingFace cache
Handles both model and dataset caches across different platforms
"""

import os
import shutil
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union
import platform
import tempfile
from datetime import datetime, timedelta

class HFCacheManager:
    """Locate, measure and clean HuggingFace cache directories.

    On construction the manager builds a list of candidate cache roots
    (platform defaults plus environment-variable overrides) and searches
    each root for the sub-directories listed in ``cache_types``.  The
    result is stored in ``found_caches`` and used by every other method.
    """

    def __init__(self, verbose: bool = False):
        """Initialize the HuggingFace cache manager.

        Args:
            verbose: When True the logger emits INFO messages (for example
                one line per removed file); otherwise only warnings and
                errors are shown.
        """
        self.verbose = verbose
        self.logger = self._setup_logging()

        # Lower-cased platform name as reported by platform.system()
        self.platform = platform.system().lower()
        self.cache_roots = self._get_default_cache_roots()

        # Relative paths / glob patterns, per cache type, that are looked
        # up underneath every cache root
        self.cache_types = {
            "models": ["models--*", "hub", ".cache/huggingface/hub"],
            "datasets": ["datasets", ".cache/huggingface/datasets"],
            "accelerate": ["accelerate", ".cache/huggingface/accelerate"],
            "transformers": ["transformers", ".cache/huggingface/transformers"]
        }

        # Mapping of cache type -> list of existing cache directories
        self.found_caches = self._discover_caches()

    def _setup_logging(self) -> logging.Logger:
        """Return the shared "HFCacheManager" logger, configured once.

        A stream handler is attached only if the logger has none yet, so
        creating several managers does not duplicate output.  The level is
        (re)set on every call according to ``self.verbose``.
        """
        logger = logging.getLogger("HFCacheManager")
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        logger.setLevel(logging.INFO if self.verbose else logging.WARNING)
        return logger

    def _get_default_cache_roots(self) -> List[Path]:
        """Get candidate cache root directories for the current platform.

        Platform defaults come first, followed by any directories named by
        the HuggingFace environment variables.  The roots are not checked
        for existence here.

        Returns:
            Candidate roots without duplicates, in a stable order.
        """
        roots: List[Path] = []

        if self.platform == "windows":
            # Windows default locations
            appdata = os.getenv("APPDATA", "")
            localappdata = os.getenv("LOCALAPPDATA", "")
            userprofile = os.getenv("USERPROFILE", "")

            if appdata:
                roots.append(Path(appdata) / "huggingface")
            if localappdata:
                roots.append(Path(localappdata) / "huggingface")
            if userprofile:
                roots.append(Path(userprofile) / ".cache" / "huggingface")

        elif self.platform == "darwin":
            # macOS default locations
            home = Path.home()
            roots.extend([
                home / "Library" / "Caches" / "huggingface",
                home / ".cache" / "huggingface"
            ])

        else:
            # Linux/Unix default locations
            xdg_cache = os.getenv("XDG_CACHE_HOME", "")
            if xdg_cache:
                roots.append(Path(xdg_cache) / "huggingface")

            roots.append(Path.home() / ".cache" / "huggingface")

        # Environment variable overrides.  Values may start with "~", so
        # expand it instead of treating it as a literal directory name.
        override_vars = (
            "HF_HOME",
            "HF_CACHE_HOME",
            "HF_HUB_CACHE",
            "TRANSFORMERS_CACHE",
            "HF_DATASETS_CACHE",
        )
        for var in override_vars:
            value = os.getenv(var)
            if value:
                roots.append(Path(value).expanduser())

        # dict.fromkeys drops duplicates while keeping first-seen order,
        # so results are reproducible between runs (a set is not ordered).
        return list(dict.fromkeys(roots))

    @staticmethod
    def _drop_duplicate_paths(paths: List[Path]) -> List[Path]:
        """Remove repeated and nested entries from *paths*.

        Two entries are duplicates when they resolve to the same location.
        An entry is nested when one of its ancestors is also in the list;
        it is dropped because walking the ancestor already covers it and
        keeping both would count its files twice.  The original (not the
        resolved) spelling of each kept path is returned, in input order.
        """
        canonical: Dict[Path, Path] = {}
        for path in paths:
            try:
                key = path.resolve()
            except (OSError, RuntimeError):
                # e.g. a symlink loop; fall back to a purely lexical form
                key = path.absolute()
            canonical.setdefault(key, path)

        return [
            path
            for key, path in canonical.items()
            if not any(parent in canonical for parent in key.parents)
        ]

    def _discover_caches(self) -> Dict[str, List[Path]]:
        """Find all HuggingFace cache directories.

        Every pattern in ``cache_types`` is looked up under every existing
        root.  Patterns containing ``*`` are expanded with ``Path.glob``.

        Returns:
            Mapping of cache type to the directories found for it, with
            duplicate and nested directories removed.
        """
        found: Dict[str, List[Path]] = {name: [] for name in self.cache_types}

        for root in self.cache_roots:
            if not root.is_dir():
                continue

            self.logger.debug("Searching for caches in: %s", root)

            for name, patterns in self.cache_types.items():
                for pattern in patterns:
                    if "*" in pattern:
                        # Sorted so the reported locations are stable
                        candidates = sorted(root.glob(pattern))
                    else:
                        candidates = [root / pattern]
                    found[name].extend(p for p in candidates if p.is_dir())

        return {
            name: self._drop_duplicate_paths(paths)
            for name, paths in found.items()
        }

    def _validate_cache_type(self, cache_type: str) -> None:
        """Raise ValueError if *cache_type* is not a key of ``cache_types``."""
        if cache_type not in self.cache_types:
            raise ValueError(f"Invalid cache type: {cache_type}")

    def _scan_paths(self, paths: List[Path]) -> Dict[str, int]:
        """Walk *paths* once and total the regular files beneath them.

        Symlinks are skipped: the hub cache stores each file once under
        ``blobs/`` and exposes it through symlinks in ``snapshots/``, so
        following links would count the same bytes twice.  Entries that
        cannot be inspected (removed mid-walk, permission denied) are
        skipped rather than aborting the scan.

        Returns:
            ``{"size_bytes": total, "file_count": count}``.
        """
        total_bytes = 0
        file_count = 0
        for base in paths:
            for entry in base.rglob("*"):
                try:
                    if entry.is_symlink() or not entry.is_file():
                        continue
                    entry_size = entry.stat().st_size
                except OSError as exc:
                    self.logger.debug("Skipping %s: %s", entry, exc)
                    continue
                total_bytes += entry_size
                file_count += 1
        return {"size_bytes": total_bytes, "file_count": file_count}

    def get_cache_size(self, cache_type: Optional[str] = None) -> Dict[str, int]:
        """Get size of cache directories in bytes.

        Args:
            cache_type: Cache type to measure, or None (or "") for all.

        Returns:
            Mapping of cache type to the total size of the regular files
            stored in its directories.  Symlinks are not counted.

        Raises:
            ValueError: If *cache_type* is not a known cache type.
        """
        if cache_type:
            self._validate_cache_type(cache_type)
            selected = [cache_type]
        else:
            selected = list(self.found_caches)

        return {
            name: self._scan_paths(self.found_caches[name])["size_bytes"]
            for name in selected
        }

    def clean_cache(
        self,
        cache_type: Optional[str] = None,
        older_than: Optional[int] = None,
        min_size: Optional[int] = None
    ) -> Dict[str, int]:
        """Clean cache directories based on age and size criteria.

        A file is removed when it matches ANY of the given criteria.  With
        no criteria nothing is removed.  Symlinks are judged by the file
        they point to, so a snapshot link is removed together with the
        blob behind it instead of being left dangling.

        Args:
            cache_type: Specific cache to clean, or None (or "") for all
            older_than: Remove files last modified more than N days ago
            min_size: Remove files larger than N bytes

        Returns:
            Dict of bytes freed per cache type.  Every known cache type is
            present; removed symlinks contribute 0 bytes.

        Raises:
            ValueError: If *cache_type* is not a known cache type.
        """
        if cache_type:
            self._validate_cache_type(cache_type)
            selected = [cache_type]
        else:
            selected = list(self.cache_types)

        freed = {name: 0 for name in self.cache_types}
        if older_than is None and min_size is None:
            return freed

        cutoff = None
        if older_than is not None:
            cutoff = datetime.now() - timedelta(days=older_than)

        for name in selected:
            # Decide what to delete before deleting anything: otherwise a
            # symlink whose target was just removed no longer looks like a
            # file and would be left behind, depending on walk order.
            doomed: Dict[Path, None] = {}
            for cache_dir in self.found_caches[name]:
                if not cache_dir.exists():
                    continue

                for path in cache_dir.rglob("*"):
                    try:
                        if not path.is_file():
                            continue
                        details = path.stat()
                    except OSError as exc:
                        self.logger.error("Failed to inspect %s: %s", path, exc)
                        continue

                    expired = (
                        cutoff is not None
                        and datetime.fromtimestamp(details.st_mtime) < cutoff
                    )
                    oversized = min_size is not None and details.st_size > min_size
                    if expired or oversized:
                        doomed[path] = None

            for path in doomed:
                try:
                    # Unlinking a symlink does not release the target's
                    # bytes, so it is not credited to the freed total.
                    reclaimed = 0 if path.is_symlink() else path.lstat().st_size
                    path.unlink()
                except OSError as exc:
                    self.logger.error("Failed to remove %s: %s", path, exc)
                else:
                    freed[name] += reclaimed
                    self.logger.info("Removed: %s", path)

        return freed

    def get_cache_info(self) -> Dict[str, Dict]:
        """Get detailed information about all caches.

        Returns:
            Mapping of cache type to a dict with the keys ``locations``
            (directory paths as strings), ``size_bytes``, ``file_count``
            (both excluding symlinks) and ``exists`` (True if at least one
            location is still present).
        """
        info = {}

        for name, paths in self.found_caches.items():
            # One walk provides both the size and the file count
            usage = self._scan_paths(paths)
            info[name] = {
                "locations": [str(p) for p in paths],
                "size_bytes": usage["size_bytes"],
                "file_count": usage["file_count"],
                "exists": any(p.exists() for p in paths)
            }

        return info

    def export_cache_info(self, output_path: Union[str, Path]) -> None:
        """Export cache information to JSON file.

        The output holds the result of ``get_cache_info`` plus a
        ``_metadata`` entry with the timestamp, platform and cache roots.
        Missing parent directories of *output_path* are created.

        Args:
            output_path: Destination file; overwritten if it exists.
        """
        info = self.get_cache_info()

        # Add metadata
        info["_metadata"] = {
            "timestamp": datetime.now().isoformat(),
            "platform": self.platform,
            "cache_roots": [str(p) for p in self.cache_roots]
        }

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Explicit encoding so the file does not depend on the locale
        with output_path.open("w", encoding="utf-8") as f:
            json.dump(info, f, indent=2)

        self.logger.info("Cache info exported to: %s", output_path)

def main():
    """Print a usage report for every cache type, then save it as JSON.

    The report lists size (in MB), file count and locations per cache
    type; the same data is written to ``hf_cache_info.json`` in the
    current working directory.
    """
    manager = HFCacheManager(verbose=True)

    banner_rule = "=" * 50
    print("\nHuggingFace Cache Analysis")
    print(banner_rule)

    report = manager.get_cache_info()

    for name in report:
        entry = report[name]
        size_mb = entry['size_bytes'] / 1024 / 1024

        print(f"\n{name.upper()}:")
        print(f"  Size: {size_mb:.2f} MB")
        print(f"  Files: {entry['file_count']}")
        print("  Locations:")
        for location in entry['locations']:
            print(f"    - {location}")

    # Persist the report next to wherever the script was launched from
    manager.export_cache_info("hf_cache_info.json")

# Run the CLI report only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()