|
|
|
|
|
import torch |
|
|
from MultiTaskConvLSTM import ConvLSTMNetwork |
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from tqdm.auto import tqdm |
|
|
from utils import ( |
|
|
mse, mae, nash_sutcliffe_efficiency, r2_score, pearson_correlation, |
|
|
spearman_correlation, percentage_error, percentage_bias, |
|
|
kendall_tau, spatial_correlation |
|
|
) |
|
|
import torch.optim as optim |
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
device = 'cpu'

# Spatial grid; batches carry this grid flattened to height * width points.
height, width = 81, 97

# Number of input time steps fed to the model and number of predicted steps.
set_lookback, set_forecast_horizon = 1, 1

batch_size = 16
time_steps_out = set_forecast_horizon
channels = 14

# Input variable names (presumably in input-channel order — 'Total precipitation'
# sits at index 10, matching precip_index below).
# NOTE(review): 13 names are listed, but `channels` and the model's input_dim
# use 14 — confirm which input channel is missing from this list.
variable_names = [
    '10 metre U wind component',
    '10 metre V wind component',
    '2 metre dewpoint temperature',
    '2 metre temperature',
    'UV visible albedo for direct radiation (climatological)',
    'Total column rain water',
    'Volumetric soil water layer 1',
    'Leaf area index, high vegetation',
    'Leaf area index, low vegetation',
    'Forecast surface roughness',
    'Total precipitation',
    'Time-integrated surface latent heat net flux',
    'Evaporation',
]

# ---------------------------------------------------------------------------
# Model, losses and optimizer
# ---------------------------------------------------------------------------
model = ConvLSTMNetwork(
    input_dim=14 * set_lookback,
    hidden_dims=[14, 32, 64],
    kernel_size=(3, 3),
    num_layers=3,
    output_channels=64 * set_forecast_horizon,
    batch_first=True,
).to(device)

loss_fn = nn.MSELoss()      # passed to evaluate() as the regression loss
bce_loss_fn = nn.BCELoss()  # passed to evaluate() as the classification loss

# NOTE(review): the optimizer is not used by the evaluation in this file;
# presumably kept for parity with the training script.
optimizer = optim.AdamW(model.parameters(), lr=0.005)

# ---------------------------------------------------------------------------
# Restore trained weights and switch to inference mode
# ---------------------------------------------------------------------------
checkpoint = torch.load("MultiTaskConvLSTM_veg_variables.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

print("Model loaded successfully")

# Precipitation cut-off (the same 0.1 value used to define the y_zero targets)
# and the position of 'Total precipitation' in variable_names.
# NOTE(review): neither name is referenced by the visible evaluation code.
threshold = 0.1
precip_index = 10
|
|
|
|
|
def _regression_metrics(y_true, y_pred):
    """Compute the regression metrics on flattened target/prediction tensors."""
    return {
        "MSE": mse(y_true, y_pred),
        "MAE": mae(y_true, y_pred),
        "NSE": nash_sutcliffe_efficiency(y_true, y_pred),
        "R2": r2_score(y_true, y_pred),
        "Pearson": pearson_correlation(y_true, y_pred),
        "Spearman": spearman_correlation(y_true, y_pred),
        "Percentage Bias": percentage_bias(y_true, y_pred),
        "Kendall Tau": kendall_tau(y_true, y_pred),
        # NOTE(review): applied to fully flattened tensors, so the spatial layout
        # is not available here — confirm this is what spatial_correlation expects.
        "Spatial Correlation": spatial_correlation(y_true, y_pred),
    }


def _classification_metrics(y_true, y_prob, threshold):
    """Compute the classification metrics from flattened labels and predicted probabilities."""
    y_true_np = y_true.numpy()
    # Hard predictions in the same dtype as the targets, so sklearn sees a single label type.
    y_hat_np = (y_prob > threshold).to(y_true.dtype).numpy()

    if torch.unique(y_true).numel() < 2:
        # roc_auc_score raises ValueError when y_true holds only one class
        # (e.g. a test period with no rain); report NaN instead of aborting
        # after the whole evaluation pass has already run.
        roc_auc = float("nan")
    else:
        roc_auc = roc_auc_score(y_true_np, y_prob.numpy())

    return {
        "Accuracy": accuracy_score(y_true_np, y_hat_np),
        "Precision": precision_score(y_true_np, y_hat_np),
        "Recall": recall_score(y_true_np, y_hat_np),
        "F1": f1_score(y_true_np, y_hat_np),
        "ROC-AUC": roc_auc,
    }


def evaluate(model, test_loader, reg_loss_fn, class_loss_fn, device, variable_names, height, width,
             class_threshold=0.7, results_path='results'):
    """
    Evaluate the model on the test set for both regression and classification tasks.

    Runs one pass over ``test_loader`` without gradients and prints the average
    losses. It then computes regression metrics on the flattened regression
    outputs and classification metrics on the flattened classification outputs,
    prints both, and saves the flattened targets/predictions with ``torch.save``.

    Args:
        model: Network called as ``model(X)`` and returning
            ``(regression_output, classification_output)``.
        test_loader: Iterable of ``(X, y, y_zero)`` batches whose last axis is a
            flattened ``height * width`` grid, i.e. shaped ``(B, T, C, height * width)``.
        reg_loss_fn: Loss applied to ``(regression_output, y)``.
        class_loss_fn: Loss applied to ``(classification_output, y_zero)``. This
            script passes ``nn.BCELoss``, so the classification output is assumed
            to be probabilities in [0, 1].
        device: Device every batch is moved to before the forward pass.
        variable_names: Not used; kept for call compatibility.
        height, width: Grid size used to unflatten the last axis of each batch.
        class_threshold: Probability above which a pixel is predicted as class 1
            (default 0.7, the value previously hard-coded).
        results_path: File the flattened targets/predictions are saved to.

    Returns:
        ``(test_total_loss, regression_metrics, classification_metrics)``.
        ``test_total_loss`` is the per-sample average of regression + classification
        loss, assuming both loss functions use mean reduction (the
        ``nn.MSELoss()`` / ``nn.BCELoss()`` defaults).

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()

    # Sums of per-batch mean losses weighted by batch size. Dividing by the total
    # sample count gives a per-sample mean even when the last batch is short.
    test_reg_loss = 0.0
    test_class_loss = 0.0
    test_total_loss = 0.0
    num_samples = 0

    y_true_reg, y_pred_reg = [], []
    y_true_class, y_pred_class = [], []

    with torch.no_grad():
        for X_test, y_test, y_zero_test in tqdm(test_loader, desc="Evaluating on Test Set"):
            X_test, y_test, y_zero_test = X_test.to(device), y_test.to(device), y_zero_test.to(device)

            # Restore the (B, T, C, H, W) layout from the flattened grid axis.
            n, time_steps_in, channels_in, _ = X_test.shape
            _, time_steps_out, channels_out, _ = y_test.shape
            X_test = X_test.view(n, time_steps_in, channels_in, height, width)
            y_test = y_test.view(n, time_steps_out, channels_out, height, width)
            y_zero_test = y_zero_test.view(n, time_steps_out, channels_out, height, width)

            regression_output, classification_output = model(X_test)

            reg_loss = reg_loss_fn(regression_output, y_test)
            class_loss = class_loss_fn(classification_output, y_zero_test)
            total_loss = reg_loss + class_loss

            test_reg_loss += reg_loss.item() * n
            test_class_loss += class_loss.item() * n
            test_total_loss += total_loss.item() * n
            num_samples += n

            y_true_reg.append(y_test.cpu())
            y_pred_reg.append(regression_output.cpu())
            y_true_class.append(y_zero_test.cpu())
            y_pred_class.append(classification_output.cpu())

    if num_samples == 0:
        raise ValueError("test_loader yielded no samples; nothing to evaluate")

    # Average over samples (not batches): the sums above are weighted by batch size.
    test_reg_loss /= num_samples
    test_class_loss /= num_samples
    test_total_loss /= num_samples

    print(f"Test Regression Loss: {test_reg_loss:.16f}")
    print(f"Test Classification Loss: {test_class_loss:.16f}")
    print(f"Test Total Loss: {test_total_loss:.16f}")

    y_true_reg_flat = torch.cat(y_true_reg, dim=0).flatten()
    y_pred_reg_flat = torch.cat(y_pred_reg, dim=0).flatten()
    y_true_class_flat = torch.cat(y_true_class, dim=0).flatten()
    y_pred_class_flat = torch.cat(y_pred_class, dim=0).flatten()

    regression_metrics = _regression_metrics(y_true_reg_flat, y_pred_reg_flat)
    print("\nRegression Metrics:")
    for metric, value in regression_metrics.items():
        print(f"{metric}: {value:.16f}")

    classification_metrics = _classification_metrics(y_true_class_flat, y_pred_class_flat, class_threshold)
    print("\nClassification Metrics:")
    for metric, value in classification_metrics.items():
        print(f"{metric}: {value:.16f}")

    torch.save({
        'y_true_reg': y_true_reg_flat,
        'y_pred_reg': y_pred_reg_flat,
        'y_true_class': y_true_class_flat,
        'y_pred_class': y_pred_class_flat,
    }, results_path)

    return test_total_loss, regression_metrics, classification_metrics
|
|
|
|
|
|
|
|
""" |
|
|
EXPECTED DATALOADER BATCH FORMAT (normalized_test_data): |
|
|
|
|
|
Each batch must be a tuple: (X_batch, y_batch, y_zero_batch) |
|
|
|
|
|
X_batch contains the input variables for the previous set_lookback hour(s). y_batch contains the next hour's precipitation.
|
|
y_zero_batch contains the next hour's precipitation thresholded at 0.1 mm/h: 0 where precipitation <= 0.1 mm/h and


1 where precipitation > 0.1 mm/h.
|
|
|
|
|
Shapes BEFORE reshaping inside `evaluate`: |
|
|
X_batch: (B, T_in, C_in, G) # G = H*W = 81*97 = 7857 |
|
|
y_batch: (B, T_out, C_out, G) |
|
|
y_zero_batch: (B, T_out, C_out, G) # binary 0/1 rain/no-rain targets (1 where precipitation > 0.1 mm/h)
|
|
|
|
|
If your preprocessing produces (B, T, C, H, W), reshape to (B, T, C, H*W) before passing the batches to evaluate().
|
|
|
|
|
DTypes: |
|
|
X_batch, y_batch: torch.float32 |
|
|
y_zero_batch: torch.float32 (will be used with BCELoss) |
|
|
|
|
|
Reshaping done in 'evaluate': |
|
|
X_test = X_batch.view(B, T_in, C_in, H, W) -> (B, T_in, C_in, 81, 97) |
|
|
y_test = y_batch.view(B, T_out, C_out, H, W) -> (B, T_out, C_out, 81, 97) |
|
|
y_zero_test = y_zero_batch.view(B, T_out, C_out, H, W) |
|
|
|
|
|
Model input: |
|
|
model expects X_test shaped (B, T_in, input_dim, H, W) |
|
|
where input_dim == 14 * set_lookback (with set_lookback=1 -> input_dim=14), as passed to ConvLSTMNetwork
|
|
|
|
|
Notes: |
|
|
• Make sure G == H*W (i.e., 7857 for 81x97). |
|
|
• C_out for precipitation should be 1 (one target channel), and y_zero_batch |
|
|
is the 0/1 rain mask at each pixel and time step (1 where precipitation > 0.1 mm/h, 0 otherwise).
|
|
• y_zero_batch should be probabilities/labels in {0,1} for BCELoss. |
|
|
""" |
|
|
|
|
|
# Load the pre-built test batches (format described in the string literal
# preceding this code). map_location=device mirrors the model checkpoint load,
# so tensors saved from a GPU session still load on this CPU-only configuration.
normalized_test_data = torch.load("data/normalized_test_data_veg_input.pth", map_location=device)

test_total_loss, regression_metrics, classification_metrics = evaluate(
    model=model,
    test_loader=normalized_test_data,
    reg_loss_fn=loss_fn,
    class_loss_fn=bce_loss_fn,
    device=device,
    variable_names=variable_names,
    height=height,
    width=width,
)