import unittest
import sys
import os

# Add the src directory to the Python path to allow imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))

from driver_api import VirtualGPUDriver

class TestVirtualGPUDriver(unittest.TestCase):
    """Tests for the public surface of ``VirtualGPUDriver``.

    Every test receives a freshly constructed and initialized driver from
    ``setUp``; ``tearDown`` shuts it down again afterwards.
    """

    # Driver configuration shared by setUp and the re-initialization in
    # test_initialization_and_shutdown, so the two cannot drift apart.
    NUM_CHIPS = 1
    VRAM_SIZE_GB = 0.1  # Small VRAM for testing.

    def _initialize_driver(self):
        """Initialize ``self.driver`` with the shared test configuration."""
        self.driver.initialize(
            num_chips=self.NUM_CHIPS,
            vram_size_gb=self.VRAM_SIZE_GB,
        )

    def setUp(self):
        self.driver = VirtualGPUDriver()
        self._initialize_driver()

    def tearDown(self):
        self.driver.shutdown()

    def test_initialization_and_shutdown(self):
        """shutdown() clears ``initialized``; initialize() sets it again."""
        self.assertTrue(self.driver.initialized)
        self.driver.shutdown()
        self.assertFalse(self.driver.initialized)
        # Re-initialize so tearDown's shutdown() runs against a live driver,
        # exactly as it does for every other test.
        self._initialize_driver()
        self.assertTrue(self.driver.initialized)

    def test_memory_allocation_and_free(self):
        """An allocation is tracked by the memory manager until it is freed."""
        size = 100
        addr = self.driver.allocate_memory(size)
        self.assertIsNotNone(addr)
        self.assertIn(addr, self.driver.memory_manager.allocated_blocks)
        self.driver.free_memory(addr)
        self.assertNotIn(addr, self.driver.memory_manager.allocated_blocks)

    def test_memory_write_and_read(self):
        """Data written to an allocation is read back unchanged."""
        size = 10
        addr = self.driver.allocate_memory(size)
        try:
            test_data = list(range(size))
            self.driver.write_memory(addr, test_data)
            read_data = self.driver.read_memory(addr, size)
            self.assertEqual(read_data, test_data)
        finally:
            # Free even if a check above fails, while the driver is still
            # initialized (tearDown shuts it down afterwards).
            self.driver.free_memory(addr)

    def test_add_and_submit_commands(self):
        """Queued commands sit in the buffer until submit_commands() drains it."""
        self.driver.add_command("test_command_1", arg1="value1")
        self.driver.add_command("test_command_2", arg2="value2")
        self.assertEqual(len(self.driver.command_processor.command_buffer), 2)

        # Submitting commands is expected to clear the buffer.
        results = self.driver.submit_commands()
        self.assertEqual(len(self.driver.command_processor.command_buffer), 0)
        self.assertIsNotNone(results)

    def test_graphics_api_buffer_creation(self):
        """create_buffer() registers a buffer and delete_buffer() removes it."""
        buffer_id = self.driver.create_buffer(1024, "vertex")
        self.assertIsNotNone(buffer_id)
        self.assertIn(buffer_id, self.driver.graphics_api.buffers)
        self.driver.delete_buffer(buffer_id)
        self.assertNotIn(buffer_id, self.driver.graphics_api.buffers)

    def test_graphics_api_buffer_data(self):
        """buffer_data() writes into the memory backing the buffer."""
        size = 100
        buffer_id = self.driver.create_buffer(size, "vertex")
        try:
            data = list(range(size))
            self.driver.buffer_data(buffer_id, data)
            # Verify by reading the buffer's backing virtual address through
            # the driver's memory interface.
            address = self.driver.graphics_api.buffers[buffer_id]["virtual_address"]
            read_data = self.driver.read_memory(address, size)
            self.assertEqual(read_data, data)
        finally:
            # Delete even if a check above fails, before tearDown's shutdown().
            self.driver.delete_buffer(buffer_id)

    def test_graphics_api_shader_compilation_and_program_linking(self):
        """Shaders compile, link into a program, and the program can be bound."""
        vertex_shader_source = "attribute vec4 position; void main() { gl_Position = position; }"
        fragment_shader_source = "void main() { gl_FragColor = vec4(1.0, 0.0, 0.0, 1.0); }"

        vertex_shader = self.driver.compile_shader(vertex_shader_source, "vertex")
        fragment_shader = self.driver.compile_shader(fragment_shader_source, "fragment")

        self.assertIsNotNone(vertex_shader)
        self.assertIsNotNone(fragment_shader)

        # Linking goes through graphics_api directly; the driver methods used
        # elsewhere in this file do not include a link wrapper.
        program = self.driver.graphics_api.link_program(vertex_shader, fragment_shader)
        self.assertIsNotNone(program)
        self.assertTrue(program["linked"])

        self.driver.use_program(program)
        self.assertEqual(self.driver.graphics_api.current_program, program)

    def test_graphics_api_framebuffer_operations(self):
        """A framebuffer exposes color/depth attachments, binds, and clears."""
        fb = self.driver.create_framebuffer(64, 64)
        self.assertIsNotNone(fb)
        self.assertIn("color", fb)
        self.assertIn("depth", fb)

        self.driver.bind_framebuffer(fb)
        self.assertEqual(self.driver.graphics_api.current_framebuffer, fb)

        self.driver.clear_color(0.0, 0.0, 1.0, 1.0)  # Clear to blue
        # TODO: read back the color buffer and verify its contents.

        self.driver.clear_depth(1.0)  # Clear depth
        # TODO: read back the depth buffer and verify its contents.

if __name__ == "__main__":
    # Allow running this module directly instead of via a test runner.
    unittest.main()


