"""
Hyperrealistic voltage-based logic gates for digital simulation.
Each gate operates on analog voltages, with digital 1/0 determined by thresholding.
Gate switching speed is parameterized to match target transistor switching rates.
"""

import random

# Constants for voltage logic (these are the defaults every LogicGate uses)
VDD = 0.7  # High supply rail (V); gates drive this level for digital 1
VSS = 0.0  # Low supply rail (V); gates drive this level for digital 0
VTH = 0.35  # Input threshold (V); a voltage strictly above it reads as digital 1

# Gate switching delay (in seconds), meant to match the fastest possible switching.
# It is the reciprocal of max_switch_freq from the project-local electron_speed
# module (assumed to be a frequency in Hz -- TODO confirm against that module).
# NOTE(review): gates store this as ``delay``, but nothing visible in this module
# reads it, so it does not affect simulated outputs.
from electron_speed import max_switch_freq
GATE_DELAY = 1 / max_switch_freq  # seconds per switch (theoretical limit)

class LogicGate:
    """Base class for voltage-level gates.

    Holds the supply rails (``vdd``/``vss``), the input switching threshold
    (``vth``) and a nominal switching delay, and converts between analog
    voltages and digital bits.
    """

    def __init__(self, vdd=VDD, vss=VSS, vth=VTH, delay=GATE_DELAY):
        # Electrical parameters; defaults come from the module-level constants.
        self.vdd, self.vss, self.vth = vdd, vss, vth
        self.delay = delay  # stored only; no method of this class reads it

    def interpret(self, voltage):
        """Threshold an analog ``voltage``: 1 if strictly above ``vth``, else 0."""
        if voltage > self.vth:
            return 1
        return 0

    def voltage(self, bit):
        """Map a digital ``bit`` (any truthy/falsy value) onto the vdd/vss rail."""
        if not bit:
            return self.vss
        return self.vdd

class NANDGate(LogicGate):
    """Two-input NAND gate: the output is low only when both inputs read high."""

    def output(self, vin1, vin2):
        """Return the noisy output voltage for input voltages ``vin1``/``vin2``."""
        a = self.interpret(vin1)
        b = self.interpret(vin2)
        result = int(not (a and b))
        # Gaussian output noise (sigma = 1% of VDD) for analog realism.
        jitter = random.gauss(0, 0.01 * self.vdd)
        return self.voltage(result) + jitter

class ANDGate(LogicGate):
    """Two-input AND gate: the output is high only when both inputs read high."""

    def output(self, vin1, vin2):
        """Return the noisy output voltage for input voltages ``vin1``/``vin2``."""
        levels = (self.interpret(vin1), self.interpret(vin2))
        high = 1 if all(levels) else 0
        # Gaussian output noise with sigma = 1% of VDD.
        sigma = 0.01 * self.vdd
        noise = random.gauss(0, sigma)
        return self.voltage(high) + noise

class ORGate(LogicGate):
    """Two-input OR gate: the output is high when at least one input reads high."""

    def output(self, vin1, vin2):
        """Return the noisy output voltage for input voltages ``vin1``/``vin2``."""
        inputs_high = [self.interpret(v) for v in (vin1, vin2)]
        bit = 1 if any(inputs_high) else 0
        # Gaussian output noise with sigma = 1% of VDD.
        noise = random.gauss(0, 0.01 * self.vdd)
        return self.voltage(bit) + noise

class NOTGate(LogicGate):
    """Inverter: drives the opposite digital level of its single input."""

    def output(self, vin):
        """Return the noisy inverted output voltage for input voltage ``vin``."""
        inverted = 1 - self.interpret(vin)
        # Gaussian output noise with sigma = 1% of VDD.
        noise = random.gauss(0, 0.01 * self.vdd)
        return self.voltage(inverted) + noise

# Demo: print one output voltage per basic gate, plus the configured gate delay.
# Gate outputs include random Gaussian noise, so printed values vary between runs.
# NOTE(review): this module has several `if __name__ == "__main__":` sections
# that run in file order; consider merging them into a single main().
if __name__ == "__main__":
    nand = NANDGate()
    andg = ANDGate()
    org = ORGate()
    notg = NOTGate()
    print("NAND(0.7, 0.7):", nand.output(0.7, 0.7))
    print("AND(0.7, 0.7):", andg.output(0.7, 0.7))
    print("OR(0.0, 0.7):", org.output(0.0, 0.7))
    print("NOT(0.7):", notg.output(0.7))
    print(f"Gate delay (s): {GATE_DELAY:.2e}")


# --- Combinational Logic ---
class XORGate(LogicGate):
    """Two-input XOR gate: the output is high when exactly one input reads high."""

    def output(self, vin1, vin2):
        """Return the noisy output voltage for input voltages ``vin1``/``vin2``."""
        # interpret() yields 0/1, so bitwise XOR gives the digital result.
        differ = self.interpret(vin1) ^ self.interpret(vin2)
        noise = random.gauss(0, 0.01 * self.vdd)
        return self.voltage(differ) + noise

class NORGate(LogicGate):
    """Two-input NOR gate: the output is high only when both inputs read low."""

    def output(self, vin1, vin2):
        """Return the noisy output voltage for input voltages ``vin1``/``vin2``."""
        a = self.interpret(vin1)
        b = self.interpret(vin2)
        neither = int(not (a or b))
        noise = random.gauss(0, 0.01 * self.vdd)
        return self.voltage(neither) + noise

class XNORGate(LogicGate):
    """Two-input XNOR gate: the output is high when both inputs read the same."""

    def output(self, vin1, vin2):
        """Return the noisy output voltage for input voltages ``vin1``/``vin2``."""
        a = self.interpret(vin1)
        b = self.interpret(vin2)
        same = int(a == b)
        noise = random.gauss(0, 0.01 * self.vdd)
        return self.voltage(same) + noise

# Example: 1-bit Full Adder (combinational logic)
class FullAdder:
    """One-bit full adder: two XORs for the sum, two ANDs plus an OR for carry."""

    def __init__(self):
        # Sum path.
        self.xor1, self.xor2 = XORGate(), XORGate()
        # Carry path.
        self.and1, self.and2 = ANDGate(), ANDGate()
        self.or1 = ORGate()

    def output(self, a, b, cin):
        """Add input voltages ``a``, ``b`` and ``cin``; return (sum, carry_out)."""
        # Gates are evaluated in a fixed order so noise draws are reproducible.
        half = self.xor1.output(a, b)
        total = self.xor2.output(half, cin)
        generate = self.and1.output(a, b)
        propagate = self.and2.output(half, cin)
        carry_out = self.or1.output(generate, propagate)
        return total, carry_out

# --- Sequential Logic ---
# SR, D, JK and T flip-flops (voltage-based): SR and D are built from the gate models, while JK and T are behavioral.
class SRFlipFlop:
    """Active-high SR latch built from NAND gates (voltage-based).

    Input levels are compared against the gates' threshold:

    * S=1, R=0 -> set:   Q driven high
    * S=0, R=1 -> reset: Q driven low
    * S=0, R=0 -> hold:  Q keeps its previous level
    * S=1, R=1 -> forbidden: Q and Q-bar are both driven high

    Each input is inverted by a NAND with both inputs tied together. The
    inverted signals drive a cross-coupled NAND pair (the classic active-low
    S-bar/R-bar latch).
    """

    # Bound on feedback re-evaluations. The NAND pair settles within a few
    # passes; the bound only guards against an unforeseen oscillation.
    _MAX_SETTLE_PASSES = 8

    def __init__(self):
        self.q = VSS
        self.q_bar = VDD  # complement of q, fed back into nand2
        self.nand1 = NANDGate()  # drives Q-bar = NAND(R-bar, Q)
        self.nand2 = NANDGate()  # drives Q     = NAND(S-bar, Q-bar)
        self.inv_s = NANDGate()  # S-bar = NAND(S, S)
        self.inv_r = NANDGate()  # R-bar = NAND(R, R)

    def output(self, s, r):
        """Apply the ``s``/``r`` input voltages and return the settled Q voltage."""
        # Invert first: a bare cross-coupled NAND pair is active-low, so feeding
        # s/r directly would make S=R=0 force Q high instead of holding.
        s_bar = self.inv_s.output(s, s)
        r_bar = self.inv_r.output(r, r)
        level = self.nand1.interpret
        # Re-evaluate the feedback loop until the digital state stops changing.
        for _ in range(self._MAX_SETTLE_PASSES):
            q_bar = self.nand1.output(r_bar, self.q)
            q = self.nand2.output(s_bar, q_bar)
            settled = level(q) == level(self.q) and level(q_bar) == level(self.q_bar)
            self.q, self.q_bar = q, q_bar
            if settled:
                break
        return self.q

class DFlipFlop:
    """Gated D latch (level-sensitive) built on an SRFlipFlop.

    While ``clk`` is above VTH the latch is transparent: S = D and R = NOT D,
    so Q follows D. While ``clk`` is low, S and R are both driven to VSS. This
    relies on the SR stage treating S = R = 0 as "hold".

    Despite the class name there is no edge detection: every call made with
    the clock high re-samples ``d``.
    """

    def __init__(self):
        self.sr = SRFlipFlop()
        # Inverter that produces R = NOT D; created once instead of on every call.
        self.inv = NOTGate()

    def output(self, d, clk):
        """Drive the latch with ``d`` gated by ``clk`` and return Q."""
        if clk > VTH:
            s, r = d, self.inv.output(d)
        else:
            s = r = VSS
        return self.sr.output(s, r)

class JKFlipFlop:
    """Behavioral JK flip-flop on voltage levels.

    While ``clk`` is above VTH:

    * J=K=1 toggles Q.
    * J=1 alone sets Q to VDD.
    * K=1 alone resets Q to VSS.
    * J=K=0 leaves Q unchanged.

    The model is level-sensitive: every call made with ``clk`` high applies the
    rule again, so repeated J=K=1 calls keep toggling.
    """

    def __init__(self):
        self.q = VSS
        # NOTE(review): j, k and nand1..nand4 are not read by output(); they are
        # kept so that existing attribute access keeps working.
        self.j = None
        self.k = None
        self.nand1 = NANDGate()
        self.nand2 = NANDGate()
        self.nand3 = NANDGate()
        self.nand4 = NANDGate()

    def output(self, j, k, clk):
        """Update Q from the ``j``/``k``/``clk`` voltages and return it."""
        if not (clk > VTH):
            return self.q
        j_high = j > VTH
        k_high = k > VTH
        if j_high and k_high:
            # Decide the toggle by threshold rather than exact float equality,
            # so a Q that holds a noisy voltage still toggles correctly.
            self.q = VSS if self.q > VTH else VDD
        elif j_high:
            self.q = VDD
        elif k_high:
            self.q = VSS
        return self.q

class TFlipFlop:
    """Behavioral T flip-flop on voltage levels.

    Q toggles between VSS and VDD whenever both ``clk`` and ``t`` are above VTH.
    The model is level-sensitive: each such call toggles again.
    """

    def __init__(self):
        self.q = VSS

    def output(self, t, clk):
        """Toggle Q if ``clk`` and ``t`` are both high; return the current Q."""
        if clk > VTH and t > VTH:
            # Threshold the stored level instead of testing ``q == VSS`` exactly.
            self.q = VSS if self.q > VTH else VDD
        return self.q

# Example: 2-bit Register (sequential logic)
class Register2Bit:
    """Two-bit register: two DFlipFlop stages that share one clock."""

    def __init__(self):
        self.dff0 = DFlipFlop()  # bit 0
        self.dff1 = DFlipFlop()  # bit 1

    def output(self, d0, d1, clk):
        """Clock ``d0``/``d1`` into bits 0/1; return their outputs as (q0, q1)."""
        return self.dff0.output(d0, clk), self.dff1.output(d1, clk)

# Example usage: combinational blocks (XOR, full adder) and sequential blocks
# (flip-flops, register). Printed values are noisy output voltages.
if __name__ == "__main__":
    # Combinational logic
    xor = XORGate()
    print("XOR(0.7, 0.0):", xor.output(0.7, 0.0))
    fa = FullAdder()
    s, c = fa.output(0.7, 0.7, 0.0)
    print("FullAdder(1,1,0): sum=", s, "carry=", c)
    # Sequential logic: each element is freshly constructed before use
    sr = SRFlipFlop()
    print("SRFlipFlop S=1, R=0:", sr.output(0.7, 0.0))
    dff = DFlipFlop()
    print("DFlipFlop D=1, CLK=1:", dff.output(0.7, 0.7))
    jk = JKFlipFlop()
    print("JKFlipFlop J=1, K=1, CLK=1:", jk.output(0.7, 0.7, 0.7))
    tff = TFlipFlop()
    print("TFlipFlop T=1, CLK=1:", tff.output(0.7, 0.7))
    reg = Register2Bit()
    print("Register2Bit D0=1, D1=0, CLK=1:", reg.output(0.7, 0.0, 0.7))


# --- Functional Units and Modules ---
# Arithmetic Logic Unit (ALU) - 1-bit (can be extended to n-bit)
class ALU1Bit:
    """One-bit ALU slice operating on input voltages.

    Supports AND, OR, ADD (via a FullAdder) and XOR. ``operate`` returns a
    ``(result, carry_out)`` voltage pair.
    """

    def __init__(self):
        self.andg = ANDGate()
        self.org = ORGate()
        self.xorg = XORGate()
        self.fadd = FullAdder()

    def operate(self, a, b, cin, op):
        """
        op: 2-bit operation selector
        00 = AND, 01 = OR, 10 = ADD, 11 = XOR
        ``cin`` is only used by ADD.
        Returns (result, carry_out); the logic ops report carry_out as VSS.
        Raises ValueError for any other ``op``.
        """
        if op == 0b00:
            return self.andg.output(a, b), VSS
        if op == 0b01:
            return self.org.output(a, b), VSS
        if op == 0b10:
            # The full adder already yields the (sum, carry_out) pair.
            return self.fadd.output(a, b, cin)
        if op == 0b11:
            return self.xorg.output(a, b), VSS
        raise ValueError("Invalid ALU op")

# 2-bit ALU (example of module composition)
class ALU2Bit:
    """Two-bit ALU composed of two ALU1Bit slices with a rippled carry."""

    def __init__(self):
        self.alu0 = ALU1Bit()  # least significant bit
        self.alu1 = ALU1Bit()  # most significant bit

    def operate(self, a0, a1, b0, b1, cin, op):
        """Apply ``op`` across both slices; return ``((r0, r1), carry_out)``.

        The carry out of slice 0 becomes the carry in of slice 1.
        """
        lsb, ripple = self.alu0.operate(a0, b0, cin, op)
        msb, carry_out = self.alu1.operate(a1, b1, ripple, op)
        return (lsb, msb), carry_out

# 2-bit Counter (using T flip-flops)
class Counter2Bit:
    """Synchronous 2-bit binary up-counter built from two T flip-flops.

    Bit 0 toggles on every clocked tick. Bit 1 toggles when bit 0 was high
    *before* the tick, giving the sequence 00 -> 01 -> 10 -> 11 -> 00 (written
    q1 q0). The flip-flops are level-sensitive, so every call with ``clk``
    above VTH advances the count by one.
    """

    def __init__(self):
        self.tff0 = TFlipFlop()
        self.tff1 = TFlipFlop()

    def tick(self, clk):
        """Apply one clock sample; return the new state as (q0, q1), LSB first."""
        # Sample T1 from Q0 *before* tff0 updates: in a synchronous counter both
        # flip-flops see pre-edge values. Using the post-update Q0 would make
        # the counter count down (00 -> 11 -> 10 -> 01).
        t1 = self.tff0.q
        self.tff0.output(VDD, clk)
        self.tff1.output(t1, clk)
        return self.tff0.q, self.tff1.q

# 2x2-bit Register File (2 registers, 2 bits each)
class RegisterFile2x2:
    """Two 2-bit registers addressed by ``sel`` (0 -> reg0, anything else -> reg1)."""

    def __init__(self):
        self.reg0 = Register2Bit()
        self.reg1 = Register2Bit()
        # Initial selection; write() and read() take an explicit ``sel`` and do
        # not update this attribute.
        self.sel = 0

    def _bank(self, sel):
        """Return the register addressed by ``sel``."""
        return self.reg0 if sel == 0 else self.reg1

    def write(self, d0, d1, clk, sel):
        """Clock ``d0``/``d1`` into the register selected by ``sel``."""
        self._bank(sel).output(d0, d1, clk)

    def read(self, sel):
        """Return the stored (q0, q1) voltages of the register selected by ``sel``."""
        target = self._bank(sel)
        return target.dff0.sr.q, target.dff1.sr.q

# Example usage of functional units (ALU slices, counter, register file).
# Printed values are noisy output voltages.
if __name__ == "__main__":
    # Functional units
    alu = ALU1Bit()
    res, cout = alu.operate(0.7, 0.0, 0.0, 0b10)
    print("ALU1Bit ADD 1+0: result=", res, "carry=", cout)
    alu2 = ALU2Bit()
    # Argument order is (a0, a1, b0, b1, cin, op); the label lists each operand
    # LSB first, so this computes a = 1 plus b = 3.
    (r0, r1), c = alu2.operate(0.7, 0.0, 0.7, 0.7, 0.0, 0b10)
    print("ALU2Bit ADD (10)+(11): result=", (r0, r1), "carry=", c)
    counter = Counter2Bit()
    print("Counter2Bit tick 1:", counter.tick(0.7))
    print("Counter2Bit tick 2:", counter.tick(0.7))
    regfile = RegisterFile2x2()
    regfile.write(0.7, 0.0, 0.7, 0)
    regfile.write(0.0, 0.7, 0.7, 1)
    print("RegisterFile2x2 read reg0:", regfile.read(0))
    print("RegisterFile2x2 read reg1:", regfile.read(1))


# --- Control Unit, Registers, and Memory Management Units ---

# Simple Control Unit (Finite State Machine for ALU operations)
class ControlUnit:
    """Four-state finite-state machine that emits ALU control signals.

    ``state`` advances 0 -> 1 -> 2 -> 3 -> 0 on each ``next_state()`` call.
    ``opcode`` is the value last passed to ``set_opcode()`` (0b00 initially).
    """

    _NUM_STATES = 4

    def __init__(self):
        self.state = 0
        self.opcode = 0b00  # default operation

    def set_opcode(self, opcode):
        """Select the ALU operation reported by get_control_signals()."""
        self.opcode = opcode

    def next_state(self):
        """Advance the state modulo 4 and return the new state."""
        self.state = (self.state + 1) % self._NUM_STATES
        return self.state

    def get_control_signals(self):
        """Return the current ALU op and register select (state parity)."""
        return dict(alu_op=self.opcode, reg_sel=self.state % 2)

# General Purpose Register (n-bit, here 2-bit for demo)
class GeneralPurposeRegister:
    """An n-bit register made of one DFlipFlop per bit (2 bits by default)."""

    def __init__(self, bits=2):
        self.bits = bits
        self.dffs = [DFlipFlop() for _ in range(bits)]

    def write(self, data, clk):
        """Clock ``data[i]`` into bit ``i``; ``data`` must have at least ``bits`` items."""
        for idx, flop in enumerate(self.dffs):
            flop.output(data[idx], clk)

    def read(self):
        """Return the stored Q voltage of every bit as a tuple, bit 0 first."""
        return tuple(flop.sr.q for flop in self.dffs)

# Simple Memory Management Unit (MMU) - address decode and register file access
class SimpleMMU:
    """Address-decoded access to a bank of GeneralPurposeRegister instances."""

    def __init__(self, num_registers=2, bits=2):
        self.registers = [GeneralPurposeRegister(bits) for _ in range(num_registers)]

    def _in_range(self, addr):
        """True when ``addr`` indexes an existing register (no negative indexing)."""
        return 0 <= addr < len(self.registers)

    def write(self, addr, data, clk):
        """Write ``data`` to register ``addr``; out-of-range addresses are ignored."""
        if not self._in_range(addr):
            return
        self.registers[addr].write(data, clk)

    def read(self, addr):
        """Return register ``addr``'s contents, or None for an out-of-range address."""
        return self.registers[addr].read() if self._in_range(addr) else None

# Example usage of control and memory units.
if __name__ == "__main__":
    # Control unit, general-purpose register and MMU
    cu = ControlUnit()
    cu.set_opcode(0b10)  # ADD
    print("ControlUnit state:", cu.next_state(), cu.get_control_signals())
    gpr = GeneralPurposeRegister(bits=2)
    gpr.write([0.7, 0.0], 0.7)
    print("GeneralPurposeRegister read:", gpr.read())
    mmu = SimpleMMU(num_registers=2, bits=2)
    mmu.write(0, [0.7, 0.0], 0.7)
    mmu.write(1, [0.0, 0.7], 0.7)
    print("SimpleMMU read reg0:", mmu.read(0))
    print("SimpleMMU read reg1:", mmu.read(1))
