"""
Disk-based storage manager for tensor and GPU operations.
Uses memory-mapped files to handle large data without RAM usage.
"""

import os
import numpy as np
import mmap
from typing import Dict, Any, Tuple
import hashlib

class DiskStorageManager:
    """Store and load NumPy arrays as raw binary files backed by shared mmaps.

    Each named batch lives in ``<storage_dir>/<name>.dat`` as the array's raw
    bytes with no header, so the same shape and dtype used when storing must
    be supplied when loading. Open mappings are cached per file path in
    ``active_mappings`` and released by :meth:`cleanup`.
    """

    def __init__(self, storage_dir: str = "storage/tensor_data"):
        """Create the manager, creating *storage_dir* (and parents) if missing."""
        self.storage_dir = storage_dir
        # file path -> open mmap, reused by get_mapping until released.
        self.active_mappings: Dict[str, mmap.mmap] = {}
        os.makedirs(storage_dir, exist_ok=True)

    def _release_mapping(self, filename: str) -> None:
        """Forget the cached mapping for *filename*, closing it when possible."""
        mapping = self.active_mappings.pop(filename, None)
        if mapping is None:
            return
        try:
            mapping.close()
        except BufferError:
            # Arrays handed out earlier still view this mapping, so it cannot
            # be closed now. Dropping our reference lets the mmap close itself
            # once the last such array is garbage-collected.
            pass

    def create_storage(self, name: str, shape: Tuple[int, ...], dtype=np.float32) -> str:
        """Create (or truncate and resize) the backing file for *name*.

        Args:
            name: Batch name; the file is ``<storage_dir>/<name>.dat``.
            shape: Array shape the file must hold.
            dtype: Element dtype; determines the bytes per element.

        Returns:
            Path of the zero-filled file, sized ``prod(shape) * itemsize`` bytes.

        Note:
            Arrays previously obtained for this file keep viewing the old
            mapping; they must not be used after the file is recreated.
        """
        filename = os.path.join(self.storage_dir, f"{name}.dat")
        size = int(np.prod(shape)) * np.dtype(dtype).itemsize

        # The file is about to be truncated and resized. A cached mapping of
        # the previous size would be stale (wrong length, or SIGBUS if the file
        # shrinks), so drop it and let get_mapping map the new file.
        self._release_mapping(filename)

        with open(filename, 'wb') as f:
            if size > 0:
                # Writing the last byte extends the file to `size` zero bytes.
                # A zero-element array just gets an empty file.
                f.seek(size - 1)
                f.write(b'\0')

        return filename

    def get_mapping(self, filename: str, shape: Tuple[int, ...], dtype=np.float32, mode='r+') -> np.ndarray:
        """Return an ndarray view over *filename*'s bytes.

        The file is mapped on first use and the mapping is cached, so later
        calls for the same path share it.

        Args:
            filename: Path of an existing data file.
            shape: Shape of the returned array.
            dtype: Element dtype of the returned array.
            mode: ``'r'`` returns a read-only view. Any other value returns a
                writable view whose writes go straight to the file.

        Returns:
            Array of ``shape``/``dtype`` over the first
            ``prod(shape) * itemsize`` bytes of the file.

        Raises:
            FileNotFoundError: If *filename* does not exist.
            ValueError: If the file is too small for ``shape`` and ``dtype``.
        """
        dtype = np.dtype(dtype)
        count = int(np.prod(shape))

        if count == 0:
            # mmap refuses empty files, and a zero-element array needs no
            # backing bytes. Still stat the path so a missing file raises
            # FileNotFoundError, as in the mapped case.
            os.stat(filename)
            array = np.empty(shape, dtype=dtype)
        else:
            mapping = self.active_mappings.get(filename)
            if mapping is None:
                fd = os.open(filename, os.O_RDWR)
                try:
                    mapping = mmap.mmap(fd, 0, access=mmap.ACCESS_WRITE)
                finally:
                    # mmap holds its own duplicate of the descriptor, so ours
                    # can be closed immediately instead of leaking.
                    os.close(fd)
                self.active_mappings[filename] = mapping
            # count= reads exactly the bytes needed and tolerates larger files.
            array = np.frombuffer(mapping, dtype=dtype, count=count).reshape(shape)

        if mode == 'r':
            array.setflags(write=False)
        return array

    def store_batch(self, name: str, data: np.ndarray):
        """Write *data* to ``<storage_dir>/<name>.dat``, replacing any prior contents."""
        data = np.asarray(data)
        filename = self.create_storage(name, data.shape, data.dtype)
        mapped_array = self.get_mapping(filename, data.shape, data.dtype)
        # Ellipsis assignment works for every rank, including 0-d arrays where
        # `[:]` raises IndexError.
        mapped_array[...] = data

    def load_batch(self, name: str, shape: Tuple[int, ...], dtype=np.float32) -> np.ndarray:
        """Return a writable file-backed view of batch *name*.

        *shape* and *dtype* must match what was stored; the file has no header.
        """
        filename = os.path.join(self.storage_dir, f"{name}.dat")
        return self.get_mapping(filename, shape, dtype)

    def cleanup(self):
        """Release every cached mapping and clear ``active_mappings``.

        Mappings still viewed by live arrays cannot be closed immediately. They
        are dropped from the cache and close once those arrays are freed.
        """
        for filename in list(self.active_mappings):
            self._release_mapping(filename)