import os
import subprocess

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

data_id = 'misc'

# Run FFM training once per (trial, mapping scale) pair.
# Mapping scales are powers of two: 2**np.arange(0, 3, 1) -> [1, 2, 4].
map_size = 2**np.arange(0, 3, 1)
trials = [1]

for trial in trials:
    for m in map_size:
        # Pass an argument list (no shell) so paths are not re-parsed by a
        # shell; every element must be a str, hence the f-string on `m`.
        args = [
            'python', 'residual_regression.py',
            '--config', f'configs/{data_id}.txt',
            '--train_images=800',
            '--num_epochs=100',
            f'--logdir=logs/{data_id}-tune-ffm/trial-{trial}/{m}',
            '--model', 'ffm',
            '--ffm_map_scale', f'{m}',
            '--batch_rays=60000',
        ]
        # Echo the equivalent command line; flush so it is not reordered
        # behind the child's output when stdout is piped/redirected.
        print(' '.join(args), flush=True)
        # check=True: stop the sweep on a failed run instead of silently
        # continuing and later plotting missing or stale results.
        subprocess.run(args, check=True)

# Make figure: test PSNR as a function of the FFM mapping scale.
params = {'legend.fontsize': 12,
          'axes.labelsize': 12,
          'axes.titlesize': 13,
          'xtick.labelsize': 10,
          'ytick.labelsize': 10}
matplotlib.rcParams.update(params)

plt.figure(figsize=(5, 4))
ax = plt.gca()

# For each mapping scale, convert every per-image PSNR to MSE, average the MSE
# over all images and all trials, then convert that average back to PSNR.
# NOTE(review): MSE = 10**(-PSNR/10) assumes PSNR was computed with a peak
# signal value of 1 -- confirm against residual_regression.py.
mean = np.zeros(len(map_size))
for i, length in enumerate(map_size):
    for trial in trials:
        result_path = (f'logs/{data_id}-tune-ffm/trial-{trial}/{length}'
                       '/ffm-L-1/result/test_psnr.npy')
        psnr = np.load(result_path)
        mse = 10**(-psnr/10)
        mean[i] += mse.mean()
    mean[i] /= len(trials)
    mean[i] = 10*np.log10(1./mean[i])

ax.plot(np.array(map_size), mean)
ax.set_xlim((map_size[0], map_size[-1]))
ax.set_xlabel('Mapping scale')
ax.set_title('(a) FFM hyperparameter tuning', y=-0.4)
ax.grid(True, which='major', alpha=.3)
# `base` replaces the `basex` keyword, which was deprecated in matplotlib 3.3
# and removed in 3.5 (passing it now raises a TypeError).
ax.set_xscale('log', base=2)
ax.set_ylabel('Mean PSNR')

plt.tight_layout()
plt.savefig('fig_ffm_sweep.png')
plt.show()
