"""
3D transformation module for GPU-accelerated geometry processing
"""
from typing import Tuple, Optional, Union
import numpy as np

class Transform3D:
    """Build 3D transformation matrices and apply them through a tensor driver.

    Matrices are 4x4 ``float32`` arrays laid out in the column-vector
    convention: translation lives in column 3 and the perspective term at
    ``[3, 2]``. Each builder stores its result in the driver under a fixed
    name derived from ``self.prefix`` and returns that name. Calling a
    builder again therefore reuses the same name (presumably replacing the
    earlier tensor; depends on ``driver.create_tensor``).
    """

    def __init__(self, driver):
        # The driver must provide create_tensor(name, array),
        # get_tensor(name) and matmul(a, b, out=name). These are the only
        # driver methods this class calls.
        self.driver = driver
        self.prefix = "transform3d"

    def projection_matrix(
        self,
        fov_y: float = 60.0,
        aspect: float = 1.0,
        near: float = 0.1,
        far: float = 1000.0,
        perspective: bool = True
    ) -> str:
        """Create a projection matrix in driver memory.

        Args:
            fov_y: Vertical field of view in degrees. Used only for the
                perspective projection; ignored when ``perspective`` is False.
            aspect: Width / height ratio of the view.
            near: Distance to the near clipping plane.
            far: Distance to the far clipping plane.
            perspective: If True, build an OpenGL-style perspective matrix.
                Otherwise build an orthographic matrix whose view volume is
                ``aspect`` wide and 1 high, centred on the view axis.

        Returns:
            The driver tensor name ``"<prefix>_proj_matrix"``.

        Raises:
            ValueError: If ``aspect`` is zero or ``near == far``; both would
                divide by zero.
        """
        if aspect == 0:
            raise ValueError("aspect must be non-zero")
        if near == far:
            raise ValueError("near and far clipping planes must differ")

        mat = np.zeros((4, 4), dtype=np.float32)
        if perspective:
            # f = cot(fov_y / 2) scales the view frustum onto the unit cube.
            f = 1.0 / np.tan(np.radians(fov_y) / 2.0)
            mat[0, 0] = f / aspect
            mat[1, 1] = f
            mat[2, 2] = (far + near) / (near - far)
            mat[2, 3] = 2.0 * far * near / (near - far)
            # Copies -z into w so the later divide by w yields perspective.
            mat[3, 2] = -1.0
        else:
            mat[0, 0] = 2.0 / aspect
            mat[1, 1] = 2.0
            mat[2, 2] = -2.0 / (far - near)
            mat[2, 3] = -(far + near) / (far - near)
            mat[3, 3] = 1.0

        mat_name = f"{self.prefix}_proj_matrix"
        self.driver.create_tensor(mat_name, mat)
        return mat_name

    def view_matrix(
        self,
        eye: Tuple[float, float, float],
        target: Tuple[float, float, float],
        up: Tuple[float, float, float] = (0, 1, 0)
    ) -> str:
        """Create a look-at view matrix in driver memory.

        The camera sits at ``eye``, looks toward ``target`` and is oriented
        so that ``up`` projects onto its vertical axis. The camera looks
        down its local -z axis.

        Returns:
            The driver tensor name ``"<prefix>_view_matrix"``.

        Raises:
            ValueError: If ``eye`` equals ``target``, or ``up`` is parallel
                to the viewing direction; no orientation can be derived then.
        """
        eye = np.array(eye, dtype=np.float32)
        target = np.array(target, dtype=np.float32)
        up = np.array(up, dtype=np.float32)

        forward = target - eye
        forward_len = np.linalg.norm(forward)
        if forward_len == 0:
            raise ValueError("eye and target must be distinct points")
        forward = forward / forward_len

        right = np.cross(forward, up)
        right_len = np.linalg.norm(right)
        if right_len == 0:
            raise ValueError("up must not be parallel to the view direction")
        right = right / right_len

        # Re-orthogonalized up vector. Its length is already ~1 because
        # right and forward are orthonormal; normalizing removes rounding.
        up = np.cross(right, forward)
        up = up / np.linalg.norm(up)

        # Rows are the camera basis (right, up, -forward); column 3 moves
        # the eye to the origin.
        mat = np.zeros((4, 4), dtype=np.float32)
        mat[0, :3] = right
        mat[1, :3] = up
        mat[2, :3] = -forward
        mat[0, 3] = -np.dot(right, eye)
        mat[1, 3] = -np.dot(up, eye)
        mat[2, 3] = np.dot(forward, eye)
        mat[3, 3] = 1.0

        mat_name = f"{self.prefix}_view_matrix"
        self.driver.create_tensor(mat_name, mat)
        return mat_name

    def model_matrix(
        self,
        translation: Tuple[float, float, float] = (0, 0, 0),
        rotation: Tuple[float, float, float] = (0, 0, 0),
        scale: Union[float, Tuple[float, float, float]] = 1.0
    ) -> str:
        """Create a model matrix ``T @ R @ S`` in driver memory.

        Args:
            translation: (tx, ty, tz) offset applied last.
            rotation: Euler angles (rx, ry, rz) in degrees. They are combined
                as ``Rz @ Ry @ Rx``, so a column vector is rotated about X
                first, then Y, then Z.
            scale: A single scalar (Python or NumPy) for uniform scale, or an
                (sx, sy, sz) triple. Scale is applied first.

        Returns:
            The driver tensor name ``"<prefix>_model_matrix"``.
        """
        # np.ndim accepts Python and NumPy scalars alike.
        if np.ndim(scale) == 0:
            scale = (scale, scale, scale)
        sx, sy, sz = scale

        rx, ry, rz = map(np.radians, rotation)

        rot_x = np.array([
            [1, 0, 0, 0],
            [0, np.cos(rx), -np.sin(rx), 0],
            [0, np.sin(rx), np.cos(rx), 0],
            [0, 0, 0, 1]
        ], dtype=np.float32)

        rot_y = np.array([
            [np.cos(ry), 0, np.sin(ry), 0],
            [0, 1, 0, 0],
            [-np.sin(ry), 0, np.cos(ry), 0],
            [0, 0, 0, 1]
        ], dtype=np.float32)

        rot_z = np.array([
            [np.cos(rz), -np.sin(rz), 0, 0],
            [np.sin(rz), np.cos(rz), 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ], dtype=np.float32)

        tx, ty, tz = translation

        mat = rot_z @ rot_y @ rot_x
        # R @ diag(sx, sy, sz) scales column j of R by s[j]. Scaling only the
        # diagonal entries would be wrong whenever R has off-diagonal terms.
        mat[:3, :3] *= np.array([sx, sy, sz], dtype=np.float32)
        mat[:3, 3] = (tx, ty, tz)

        mat_name = f"{self.prefix}_model_matrix"
        self.driver.create_tensor(mat_name, mat)
        return mat_name

    def transform_vertices(
        self,
        vertices_name: str,
        model_matrix_name: Optional[str] = None,
        view_matrix_name: Optional[str] = None,
        proj_matrix_name: Optional[str] = None
    ) -> str:
        """Transform vertices through the model-view-projection pipeline.

        Vertices whose last dimension is 3 are first extended with w = 1 and
        stored as ``"<prefix>_homogeneous"``. Each supplied matrix is then
        applied in order model, view, projection via ``driver.matmul``. If
        the projection matrix has a non-zero ``[3, 2]`` entry (perspective),
        the result is divided by its w component.

        Returns:
            The name of the final tensor. This is the input name (or the
            homogeneous copy) when no matrices are given.
        """
        # NOTE(review): the matrices built by this class use the column-vector
        # convention, but vertices here are stored one per row. If
        # driver.matmul(a, b) computes a @ b, each product below needs the
        # transposed matrix (v @ M.T). Confirm against the driver.
        vertices = self.driver.get_tensor(vertices_name)

        if vertices.shape[-1] == 3:
            ones = np.ones((*vertices.shape[:-1], 1), dtype=vertices.dtype)
            vertices = np.concatenate([vertices, ones], axis=-1)
            homogeneous_name = f"{self.prefix}_homogeneous"
            self.driver.create_tensor(homogeneous_name, vertices)
            vertices_name = homogeneous_name

        result_name = vertices_name

        if model_matrix_name is not None:
            result_name = f"{self.prefix}_model_result"
            self.driver.matmul(vertices_name, model_matrix_name, out=result_name)

        if view_matrix_name is not None:
            view_result = f"{self.prefix}_view_result"
            self.driver.matmul(result_name, view_matrix_name, out=view_result)
            result_name = view_result

        if proj_matrix_name is not None:
            proj_result = f"{self.prefix}_proj_result"
            self.driver.matmul(result_name, proj_matrix_name, out=proj_result)
            result_name = proj_result

            # [3, 2] is zero for orthographic and -1 for perspective matrices.
            if self.driver.get_tensor(proj_matrix_name)[3, 2] != 0:
                verts = self.driver.get_tensor(result_name)
                # Vertices with w == 0 yield inf/NaN here (assuming
                # get_tensor returns a NumPy array).
                w = verts[..., 3:]
                verts = verts / w
                self.driver.create_tensor(result_name, verts)

        return result_name

    def normal_matrix(self, model_matrix_name: str) -> str:
        """Create the normal matrix: the inverse transpose of the model's 3x3 part.

        Returns:
            The driver tensor name ``"<prefix>_normal_matrix"``.

        Raises:
            numpy.linalg.LinAlgError: If the upper-left 3x3 block is singular
                (e.g. a zero scale factor).
        """
        model_mat = self.driver.get_tensor(model_matrix_name)
        normal_mat = np.linalg.inv(model_mat[:3, :3]).T

        mat_name = f"{self.prefix}_normal_matrix"
        self.driver.create_tensor(mat_name, normal_mat)
        return mat_name

    def transform_normals(
        self,
        normals_name: str,
        normal_matrix_name: str
    ) -> str:
        """Transform normal vectors with a normal matrix and renormalize them.

        Zero-length results stay zero vectors instead of becoming NaN.

        Returns:
            The driver tensor name ``"<prefix>_transformed_normals"``.
        """
        # NOTE(review): same row-vector vs column-vector concern as in
        # transform_vertices. Confirm what driver.matmul(a, b) computes.
        result_name = f"{self.prefix}_transformed_normals"
        self.driver.matmul(normals_name, normal_matrix_name, out=result_name)

        normals = self.driver.get_tensor(result_name)
        lengths = np.linalg.norm(normals, axis=-1, keepdims=True)
        # Avoid 0 / 0 for degenerate normals.
        lengths = np.where(lengths == 0, 1.0, lengths)
        normals = normals / lengths
        self.driver.create_tensor(result_name, normals)

        return result_name
