"""
Base module class for Helium deep learning framework
"""
from typing import Dict, List, Optional, Union, Any, Tuple
from virtual_gpu_driver.src.driver_api import VirtualGPUDriver
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, DType, Device, Layout, Tensor

class HeliumModule:
    """Base class for all neural network modules in Helium.

    A module owns three name-keyed registries:

    * ``_parameters`` -- tensors added with :meth:`register_parameter`;
    * ``_buffers`` -- tensors added with :meth:`register_buffer`;
    * ``_modules`` -- child modules added with :meth:`add_module`.

    :meth:`parameters`, :meth:`to` and :meth:`train` recurse into the child
    modules.
    """

    def __init__(self):
        self.training = True   # flipped by train()/eval(), propagated to children
        self.device = None     # last device passed to to(); None until then
        self._parameters = {}  # name -> parameter tensor
        self._buffers = {}     # name -> buffer tensor
        self._modules = {}     # name -> child HeliumModule

    def parameters(self) -> Dict[str, "Tensor"]:
        """Return every parameter of this module and of its descendants.

        Own parameters are keyed by their registered name; parameters of a
        child added as ``"child"`` are keyed ``"child.<name>"`` (recursively).
        Own parameters come first, followed by each child's parameters in the
        order the children were added.

        NOTE(review): registered names are not checked for ``"."``, so a
        dotted parameter name can collide with a qualified child name; on a
        collision the later entry's value wins.

        Returns:
            A new dict mapping qualified name to parameter tensor.
        """
        params = dict(self._parameters)
        for prefix, child in self._modules.items():
            params.update(
                (f"{prefix}.{name}", param)
                for name, param in child.parameters().items()
            )
        return params

    def to(self, device: "Device") -> "HeliumModule":
        """Move this module's parameters, buffers and children to ``device``.

        Parameters are moved before buffers, then each child module is moved
        recursively. ``self.device`` is set to ``device``.

        Args:
            device: Target device, forwarded unchanged to each tensor's
                ``to`` method and to each child's :meth:`to`.

        Returns:
            ``self``, so calls can be chained.
        """
        self.device = device
        for registry in (self._parameters, self._buffers):
            # Snapshot the items so the registry can be updated while looping.
            for name, tensor in list(registry.items()):
                moved = tensor.to(device)
                # Tensor.to may be out-of-place (returning a new tensor, as in
                # PyTorch); keep its result so the registry does not keep
                # pointing at the old-device tensor. A None result is taken to
                # mean the tensor was moved in place.
                if moved is not None:
                    registry[name] = moved
        for child in self._modules.values():
            child.to(device)
        return self

    def train(self, mode: bool = True) -> "HeliumModule":
        """Set training mode on this module and, recursively, its children.

        Args:
            mode: Value stored in ``self.training`` (default ``True``).

        Returns:
            ``self``, so calls can be chained.
        """
        self.training = mode
        for module in self._modules.values():
            module.train(mode)
        return self

    def eval(self) -> "HeliumModule":
        """Set evaluation mode; equivalent to ``self.train(False)``."""
        return self.train(False)

    def register_parameter(self, name: str, param: "Tensor"):
        """Register ``param`` under ``name`` in this module's parameters.

        Raises:
            KeyError: if a parameter with this name is already registered.
        """
        if name in self._parameters:
            raise KeyError(f"Parameter {name} already registered")
        self._parameters[name] = param

    def register_buffer(self, name: str, buffer: "Tensor"):
        """Register ``buffer`` under ``name`` in this module's buffers.

        Buffers are moved by :meth:`to` but are not returned by
        :meth:`parameters`.

        Raises:
            KeyError: if a buffer with this name is already registered.
        """
        if name in self._buffers:
            raise KeyError(f"Buffer {name} already registered")
        self._buffers[name] = buffer

    def add_module(self, name: str, module: 'HeliumModule'):
        """Add ``module`` as a child named ``name``.

        The child's parameters appear in :meth:`parameters` prefixed with
        ``"<name>."``.

        Raises:
            KeyError: if a child with this name is already added.
        """
        if name in self._modules:
            raise KeyError(f"Module {name} already added")
        self._modules[name] = module
