{
  "save_path": "saved_standard_challenging_context32_nocond_cont_cont_all_cont_eval",
  "model": {
    "base_learning_rate": 8e-05,
    "target": "latent_diffusion.ldm.models.diffusion.ddpm.LatentDiffusion",
    "params": {
      "linear_start": 0.0015,
      "linear_end": 0.0195,
      "num_timesteps_cond": 1,
      "log_every_t": 200,
      "timesteps": 1000,
      "first_stage_key": "image",
      "cond_stage_key": "action_",
      "scheduler_sampling_rate": 0.0,
      "hybrid_key": "c_concat",
      "image_size": [64, 48],
      "channels": 3,
      "cond_stage_trainable": false,
      "conditioning_key": "hybrid",
      "monitor": "val/loss_simple_ema",
      "unet_config": {
        "target": "latent_diffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel",
        "params": {
          "image_size": [64, 48],
          "in_channels": 8,
          "out_channels": 4,
          "model_channels": 192,
          "attention_resolutions": [8, 4, 2],
          "num_res_blocks": 2,
          "channel_mult": [1, 2, 3, 5],
          "num_head_channels": 32,
          "use_spatial_transformer": false,
          "transformer_depth": 1
        }
      },
      "temporal_encoder_config": {
        "target": "latent_diffusion.ldm.modules.encoders.temporal_encoder.TemporalEncoder",
        "params": {
          "input_channels": 6,
          "hidden_size": 1024,
          "num_layers": 1,
          "dropout": 0.1,
          "output_channels": 4,
          "output_height": 48,
          "output_width": 64
        }
      },
      "first_stage_config": {
        "target": "latent_diffusion.ldm.models.autoencoder.AutoencoderKL",
        "params": {
          "embed_dim": 4,
          "monitor": "val/rec_loss",
          "ddconfig": {
            "double_z": true,
            "z_channels": 4,
            "resolution": 256,
            "in_channels": 3,
            "out_ch": 3,
            "ch": 128,
            "ch_mult": [1, 2, 4, 4],
            "num_res_blocks": 2,
            "attn_resolutions": [],
            "dropout": 0.0
          },
          "lossconfig": {
            "target": "torch.nn.Identity"
          }
        }
      },
      "cond_stage_config": "__is_unconditional__"
    }
  },
  "data": {
    "target": "data.data_processing.datasets.DataModule",
    "params": {
      "batch_size": 8,
      "num_workers": 1,
      "wrap": false,
      "shuffle": true,
      "drop_last": true,
      "pin_memory": true,
      "prefetch_factor": 2,
      "persistent_workers": true,
      "train": {
        "target": "data.data_processing.datasets.ActionsData",
        "params": {
          "data_csv_path": "desktop_sequences_filtered_with_desktop_1.5k.challenging.train.target_frames.csv",
          "normalization": "standard",
          "context_length": 32
        }
      }
    }
  },
  "lightning": {
    "trainer": {
      "benchmark": false,
      "max_epochs": 6400,
      "limit_val_batches": 0,
      "accelerator": "gpu",
      "gpus": 1,
      "accumulate_grad_batches": 999999,
      "gradient_clip_val": 1,
      "checkpoint_callback": true
    }
  }
}