import os
import random
import uuid

# Support both in-repo and standalone imports
try:
    # In-repo imports (when running from OpenEnv repository)
    from openenv.core.env_server import Environment

    from ..models import WildfireAction, WildfireObservation, WildfireState
except ImportError:
    # Standalone imports (when environment is standalone with openenv-core from pip)
    from openenv_core.env_server import Environment
    from wildfire_env.models import WildfireAction, WildfireObservation, WildfireState


# --- Helpers ---

# Wind direction name -> (dx, dy) offset on the grid ("N" points towards row 0).
DIRS_8 = {
    "N": (0, -1),
    "NE": (1, -1),
    "E": (1, 0),
    "SE": (1, 1),
    "S": (0, 1),
    "SW": (-1, 1),
    "W": (-1, 0),
    "NW": (-1, -1),
    "CALM": (0, 0),
}


def idx(x: int, y: int, w: int) -> int:
    """Return the row-major flat index of cell (x, y) in a grid of width ``w``."""
    # Defensive type conversion to ensure all parameters are integers
    x, y, w = int(x), int(y), int(w)
    return y * w + x


def in_bounds(x: int, y: int, w: int, h: int) -> bool:
    """Return True if (x, y) lies inside a ``w`` x ``h`` grid."""
    # Defensive type conversion to ensure all parameters are integers
    x, y, w, h = int(x), int(y), int(w), int(h)
    return 0 <= x < w and 0 <= y < h


class WildfireEnvironment(Environment):
    """
    Weather-aware wildfire simulation on a ``w`` x ``h`` grid.

    Grid encodings:
        0 = ash (burned out)
        1 = fuel / vegetation
        2 = burning
        3 = firebreak
        4 = watered / damp

    Each step:
        - the agent acts ("water", "break" or "wait") on one cell
        - burning cells spread to 8-neighbour fuel, biased by wind and
          suppressed by humidity
        - burning cells burn for ``burn_lifetime`` ticks, then become ash
        - damp cells revert to fuel after ``damp_lifetime`` ticks

    Environment-variable overrides (read once in ``__init__``):
        WILDFIRE_WIDTH, WILDFIRE_HEIGHT, WILDFIRE_HUMIDITY, WILDFIRE_WIND
    """

    def __init__(
        self,
        width: int = 32,
        height: int = 32,
        base_ignite_prob: float = 0.30,
        wind_bias: float = 0.20,  # kept for compatibility (not used by the spread model)
        diag_factor: float = 0.7,  # kept for compatibility (not used by the spread model)
        humidity: float = 0.25,
        init_sources: int = 2,
        seed: int = 3407,
        max_steps: int = 128,
        water_capacity: int = 8,  # small budget to encourage strategic water use
        break_capacity: int = 50,
    ):
        """
        Configure the environment; the episode itself is created by ``reset()``.

        Args:
            width, height: grid size in cells (overridable via WILDFIRE_WIDTH /
                WILDFIRE_HEIGHT). Must be positive.
            base_ignite_prob: per-neighbour ignition probability before the wind,
                diagonal and humidity multipliers are applied.
            wind_bias, diag_factor: stored for compatibility only; the spread
                model uses fixed multipliers instead.
            humidity: baseline humidity (overridable via WILDFIRE_HUMIDITY).
            init_sources: number of random ignition points placed on reset.
            seed: seed for the environment's private ``random.Random``.
            max_steps: episode step limit.
            water_capacity: per-episode budget of "water" actions.
            break_capacity: per-episode budget of "break" actions.

        Raises:
            ValueError: if the (possibly overridden) width or height is not positive.
        """
        super().__init__()

        # --- Env-var overrides (optional) ---
        width = int(os.environ.get("WILDFIRE_WIDTH", width))
        height = int(os.environ.get("WILDFIRE_HEIGHT", height))
        humidity = float(os.environ.get("WILDFIRE_HUMIDITY", humidity))
        forced_wind = os.environ.get("WILDFIRE_WIND", None)

        # Store config (ensure integers)
        self.w = int(width)
        self.h = int(height)
        # An empty grid would otherwise fail much later (randrange on reset, or a
        # division by zero in the end-of-episode bonus).
        if self.w <= 0 or self.h <= 0:
            raise ValueError(
                f"Grid dimensions must be positive, got width={self.w}, height={self.h}"
            )
        self.base_ignite_prob = base_ignite_prob
        self.wind_bias = wind_bias
        self.diag_factor = diag_factor
        self.init_humidity = humidity
        self.init_sources = init_sources
        self.rng = random.Random(seed)
        self.max_steps = max_steps
        self.init_water = water_capacity
        self.init_breaks = break_capacity
        self.forced_wind = forced_wind

        # Burning cells turn to ash after this many ticks.
        self.burn_lifetime = 3
        # Watered cells stay damp (immune to ignition) for this many ticks.
        self.damp_lifetime = 6

        # The real state is built in reset(). WildfireState() is not constructed
        # here because of Pydantic/dataclass conflicts; the `state` property
        # handles the None case.
        self._state: WildfireState | None = None

    # --- Core API ---
    def reset(self) -> WildfireObservation:
        """
        Start a new episode and return its initial observation.

        Fills the grid with fuel, picks the wind direction (WILDFIRE_WIND if it
        names a key of DIRS_8, matched case-insensitively; otherwise random),
        jitters humidity by up to +/-0.05 (clamped to [0, 1]) and ignites
        ``init_sources`` random cells (the same cell may be picked twice).
        """
        # Ensure w and h are integers (defensive type conversion)
        w, h = int(self.w), int(self.h)

        # Start with all fuel
        grid = [1] * (w * h)

        # Wind (forced if provided and valid). Normalising lets values such as
        # "ne" or " NE " work instead of being silently ignored.
        forced = str(self.forced_wind or "").strip().upper()
        if forced in DIRS_8:
            wind_dir = forced
        else:
            wind_dir = self.rng.choice(list(DIRS_8))

        # Humidity small variation around init
        humidity = min(1.0, max(0.0, self.init_humidity + self.rng.uniform(-0.05, 0.05)))

        # Place initial fires
        for _ in range(self.init_sources):
            x = self.rng.randrange(w)
            y = self.rng.randrange(h)
            i = idx(x, y, w)
            if 0 <= i < len(grid):
                grid[i] = 2

        burn_timers = [0] * (w * h)

        # model_construct bypasses Pydantic validation (dataclass/Pydantic compatibility)
        self._state = WildfireState.model_construct(
            episode_id=str(uuid.uuid4()),
            step_count=0,
            total_burned=0,
            total_extinguished=0,
            last_action="reset",
            width=w,
            height=h,
            wind_dir=wind_dir,
            humidity=humidity,
            remaining_water=self.init_water,
            remaining_breaks=self.init_breaks,
            grid=grid,
            burn_timers=burn_timers,
        )

        return self._make_observation(reward_hint=0.0)

    def step(self, action: WildfireAction) -> WildfireObservation:
        """
        Apply the agent's action, advance the fire by one tick and return the
        resulting observation with ``done`` and ``reward`` set.

        Reward shaping: action reward, -0.15 per net new burning cell (+0.10 per
        net extinguished one), -0.05 per newly burned cell, -0.01 per step, plus
        end-of-episode bonuses.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        st = self._state
        # Without this guard, step() either crashes on None or silently "plays"
        # the empty placeholder state created by the `state` property.
        if st is None or not st.grid:
            raise RuntimeError("WildfireEnvironment.step() called before reset()")

        reward = 0.0

        # --- Agent action effects ---
        if (
            action.action == "water"
            and st.remaining_water > 0
            and action.x is not None
            and action.y is not None
        ):
            reward += self._apply_water(action.x, action.y)
        elif (
            action.action == "break"
            and st.remaining_breaks > 0
            and action.x is not None
            and action.y is not None
        ):
            reward += self._apply_break(action.x, action.y)
        elif action.action == "wait":
            pass
        else:
            reward -= 0.05  # unknown action, missing coordinates, or exhausted resource

        # --- Natural fire dynamics ---
        prev_burning = self._burning_count()
        prev_burned = sum(1 for v in st.grid if v == 0)

        newly_burned = self._spread_fire()
        new_burning = self._burning_count()
        now_burned = sum(1 for v in st.grid if v == 0)

        st.total_burned += newly_burned
        st.step_count += 1
        st.last_action = action.action

        # --- Spread vs containment shaping ---
        spread_delta = new_burning - prev_burning
        burned_delta = now_burned - prev_burned

        # Strong penalty for spread, reward for shrinkage
        if spread_delta > 0:
            reward -= 0.15 * spread_delta
        elif spread_delta < 0:
            reward += 0.10 * abs(spread_delta)

        # Mild penalty for newly burned cells (area loss)
        if burned_delta > 0:
            reward -= 0.05 * burned_delta

        # Small time penalty to prefer fast control
        reward -= 0.01

        done = self._is_done()

        # --- End of episode bonuses ---
        if done:
            n_cells = self.w * self.h  # positive: validated in __init__
            saved_ratio = self._saved_cells() / n_cells
            burned_ratio = now_burned / n_cells
            burning_left = self._burning_count()

            # Big containment bonus
            if burning_left == 0:
                reward += 0.5 + 0.5 * saved_ratio

            # Fallback proportional reward
            reward += 0.2 * (1.0 - burned_ratio)

        obs = self._make_observation(reward_hint=reward)
        obs.done = done
        obs.reward = reward
        return obs

    # --- Internal mechanics ---
    def _apply_water(self, x: int, y: int) -> float:
        """
        Water cell (x, y) and return the action reward.

        Burning -> damp (+0.25, counts as extinguished); fuel -> damp (-0.10);
        already damp, ash or firebreak -> no change (-0.05). Consumes one unit
        of water unless (x, y) is out of bounds (-0.05, nothing consumed).
        """
        st = self._state
        # Ensure x and y are integers (defensive type conversion)
        x, y = int(x), int(y)

        if not in_bounds(x, y, self.w, self.h):
            return -0.05

        # step() already checks the budget; this only triggers on direct calls.
        if st.remaining_water <= 0:
            return -0.5

        i = idx(x, y, self.w)
        if i < 0 or i >= len(st.grid):
            return -0.05

        reward = 0.0
        cell = st.grid[i]

        if cell == 2:
            st.grid[i] = 4  # extinguish & dampen
            st.burn_timers[i] = 0
            st.total_extinguished += 1
            reward += 0.25
        elif cell == 1:
            st.grid[i] = 4  # dampen fuel (mild penalty to avoid spamming)
            st.burn_timers[i] = 0
            reward -= 0.10
        elif cell == 4:
            reward -= 0.05  # redundant watering
        else:
            reward -= 0.05  # watering ash/break gives slight penalty

        st.remaining_water -= 1
        return reward

    def _apply_break(self, x: int, y: int) -> float:
        """
        Build a firebreak on cell (x, y) and return the action reward.

        Fuel or damp -> firebreak (+0.15); burning -> firebreak (-0.02, stops it
        burning but is not counted as extinguished); existing firebreak (-0.01)
        or ash (-0.02) unchanged. Consumes one break unless (x, y) is out of
        bounds (-0.05, nothing consumed).
        """
        st = self._state
        # Ensure x and y are integers (defensive type conversion)
        x, y = int(x), int(y)

        if not in_bounds(x, y, self.w, self.h):
            return -0.05

        i = idx(x, y, self.w)
        if i < 0 or i >= len(st.grid):
            return -0.05

        reward = 0.0
        cell = st.grid[i]

        if cell in (1, 4):
            st.grid[i] = 3
            st.burn_timers[i] = 0
            reward += 0.15  # make firebreaks attractive
        elif cell == 2:
            st.grid[i] = 3
            st.burn_timers[i] = 0
            reward -= 0.02
        elif cell == 3:
            reward -= 0.01
        else:
            reward -= 0.02

        st.remaining_breaks -= 1
        return reward

    def _spread_fire(self) -> int:
        """
        Advance the fire by one tick and return how many cells turned to ash.

        Balanced wildfire spread model:
          - burning cells persist for ``burn_lifetime`` ticks before turning to ash
          - 8-direction spread, diagonals weaker (x0.6)
          - wind doubles spread downwind and halves it upwind
          - humidity scales ignition probability by (1 - humidity)
          - damp cells (4) cannot ignite and revert to fuel after ``damp_lifetime`` ticks

        Ignitions are decided against the grid as it was at the start of the
        tick, so fire advances at most one cell per step.
        """
        st = self._state
        new_grid = st.grid[:]
        newly_burned = 0

        # Ensure w and h are integers (defensive type conversion)
        w, h = int(self.w), int(self.h)

        # 8-neighbour offsets; the order fixes the sequence of RNG draws.
        neighbors = ((-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (1, -1), (-1, 1), (1, 1))
        wx, wy = DIRS_8.get(st.wind_dir, (0, 0))

        base = self.base_ignite_prob
        humidity_factor = 1.0 - st.humidity

        ignite_flags = [False] * (w * h)

        # First pass: age burning cells and roll ignitions for their fuel neighbours.
        for y in range(h):
            for x in range(w):
                i = idx(x, y, w)
                if i < 0 or i >= len(st.grid):
                    continue
                if st.grid[i] != 2:
                    continue

                st.burn_timers[i] += 1

                for dx, dy in neighbors:
                    nx, ny = x + dx, y + dy
                    if not in_bounds(nx, ny, w, h):
                        continue
                    ni = idx(nx, ny, w)
                    if ni < 0 or ni >= len(st.grid):
                        continue

                    # Only fuel can ignite; damp (4), firebreak, ash and
                    # burning cells are skipped without drawing from the RNG.
                    if st.grid[ni] != 1:
                        continue

                    if (dx, dy) == (wx, wy):
                        wind_mult = 2.0
                    elif (dx, dy) == (-wx, -wy):
                        wind_mult = 0.5
                    else:
                        wind_mult = 1.0

                    diag_mult = 0.6 if (dx != 0 and dy != 0) else 1.0

                    p = base * humidity_factor * wind_mult * diag_mult
                    p = max(0.0, min(1.0, p))

                    if self.rng.random() < p and ni < len(ignite_flags):
                        ignite_flags[ni] = True

        # Second pass: apply state transitions.
        for i, cell in enumerate(st.grid):
            if i >= len(new_grid) or i >= len(st.burn_timers):
                continue

            if cell == 2:
                if st.burn_timers[i] >= self.burn_lifetime:
                    new_grid[i] = 0  # burned out -> ash
                    newly_burned += 1
                else:
                    new_grid[i] = 2  # keep burning
            elif i < len(ignite_flags) and ignite_flags[i] and new_grid[i] == 1:
                new_grid[i] = 2
                st.burn_timers[i] = 0
            elif cell == 4:
                # Water stays damp for several ticks before reverting to fuel
                st.burn_timers[i] += 1
                if st.burn_timers[i] >= self.damp_lifetime:
                    new_grid[i] = 1

        st.grid = new_grid
        return newly_burned

    def _burning_count(self) -> int:
        """Number of cells currently burning (code 2)."""
        return sum(1 for v in self._state.grid if v == 2)

    def _saved_cells(self) -> int:
        """Number of cells not turned to ash (fuel, burning, firebreak or damp)."""
        return sum(1 for v in self._state.grid if v in (1, 2, 3, 4))

    def _is_done(self) -> bool:
        """Episode ends when nothing is burning or the step limit is reached."""
        return self._burning_count() == 0 or self._state.step_count >= self.max_steps

    def _make_observation(self, reward_hint: float = 0.0) -> WildfireObservation:
        """Build an observation snapshot (the grid is copied) of the current state."""
        st = self._state
        burning = self._burning_count()
        burned = sum(1 for v in st.grid if v == 0)

        # model_construct bypasses Pydantic validation (dataclass/Pydantic compatibility)
        return WildfireObservation.model_construct(
            grid=st.grid[:],
            width=self.w,
            height=self.h,
            step=st.step_count,
            wind_dir=st.wind_dir,
            humidity=st.humidity,
            burning_count=burning,
            remaining_water=st.remaining_water,
            remaining_breaks=st.remaining_breaks,
            burned_count=burned,
            reward_hint=reward_hint,
        )

    # --- Required abstract property implementation ---
    @property
    def state(self) -> WildfireState:
        """
        Return the current environment state.

        Before ``reset()`` this lazily creates an empty placeholder state
        (no grid, CALM wind) so the property never returns None.
        """
        if self._state is None:
            # model_construct bypasses Pydantic validation (dataclass/Pydantic compatibility)
            self._state = WildfireState.model_construct(
                episode_id="",
                step_count=0,
                total_burned=0,
                total_extinguished=0,
                last_action="reset",
                width=0,
                height=0,
                wind_dir="CALM",
                humidity=0.25,
                remaining_water=self.init_water,
                remaining_breaks=self.init_breaks,
                grid=[],
                burn_timers=[],
            )
        return self._state