import torch
import numpy as np
from tqdm import tqdm as tqdm
import os
import imageio

from load_approx_res import ApproxResSurfaceDataset
from model import *
from rd_wrapper import rd_wrapper

# ---------------------------------------------------------------------------
# Scene / output configuration (no network is loaded in this section).
# ---------------------------------------------------------------------------
data_id = 'chicken'
data_dir = f'data/{data_id}'
output_dir = 'GP_training'  # residual .npy files are written here
data_type = 'blender'

# Mesh path handed to ApproxResSurfaceDataset for every split.
obj_path = f'{data_dir}/{data_id}-sh.obj'
dsize = [512, 512]  # NOTE(review): not referenced in the visible code

# View identifiers per split: train uses only r_18, test uses only r_0.
partition = {
    'train': [f'./train/r_{view}' for view in range(18, 19)],
    'test': [f'./test/r_{view}' for view in range(1)],
}
# Deterministic ordering; no batch_size given, so DataLoader's default applies.
params = {'shuffle': False, 'num_workers': 1}

# Camera-metadata filename depends on the dataset convention.
if data_type == 'blender':
    train_meta, test_meta = 'transforms_train.json', 'transforms_test.json'
else:
    train_meta, test_meta = 'cameras_train.json', 'cameras_test.json'

# NOTE(review): meaning of the `L` argument (0 for train, -1 for test) is
# defined in load_approx_res -- TODO confirm.
train_set = ApproxResSurfaceDataset(data_type, data_dir, obj_path,
                                    partition['train'], train_meta, L=0)
train_generator = torch.utils.data.DataLoader(train_set, **params)
test_set = ApproxResSurfaceDataset(data_type, data_dir, obj_path,
                                   partition['test'], test_meta, L=-1)
test_generator = torch.utils.data.DataLoader(test_set, **params)

# ---------------------------------------------------------------------------
# Dump residual training / testing data to disk.
# ---------------------------------------------------------------------------


def _to_numpy(value):
    """Return *value* in a form ``np.save`` accepts, moving tensors to host memory.

    ``np.save`` converts a tensor implicitly via ``Tensor.__array__``. That
    conversion raises for CUDA tensors (and for tensors that require grad),
    so tensors are detached and copied to CPU explicitly. Non-tensor values
    are passed through unchanged, as before.
    """
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return value


@torch.no_grad()
def _dump_residuals(loader, prefix=''):
    """Write every item yielded by *loader* to ``{output_dir}/{prefix}{i}.npy``.

    Each item is unpacked as ``(x, residual, mask, approx)``. Only ``x``,
    ``residual`` and ``mask`` are saved. They are written back-to-back into a
    single file with three ``np.save`` calls. A reader must therefore call
    ``np.load`` three times on the same open handle, in that order. ``approx``
    is not written.

    Args:
        loader: iterable of 4-tuples, e.g. a ``torch.utils.data.DataLoader``.
        prefix: filename prefix distinguishing splits (``''`` train, ``'test'`` test).
    """
    for i, (x, residual, mask, _approx) in enumerate(tqdm(loader)):
        with open(os.path.join(output_dir, f'{prefix}{i}.npy'), 'wb') as f:
            for arr in (x, residual, mask):
                np.save(f, _to_numpy(arr))


os.makedirs(output_dir, exist_ok=True)

print("Generating training residuals")
_dump_residuals(train_generator)

print("Generating testing residuals")
_dump_residuals(test_generator, prefix='test')