Tharun156 committed
Commit f7400bf · verified · 1 parent: 859f17b

Upload 149 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +5 -0
  2. configs/beat2_rvqvae.yaml +134 -0
  3. configs/diffuser_rvqvae_128.yaml +96 -0
  4. configs/model_config.yaml +71 -0
  5. configs/sc_model_config.yaml +37 -0
  6. configs/sc_model_holistic_config.yaml +37 -0
  7. configs/sc_reflow_model_config.yaml +37 -0
  8. configs/shortcut.yaml +96 -0
  9. configs/shortcut_hf.yaml +96 -0
  10. configs/shortcut_holistic.yaml +96 -0
  11. configs/shortcut_reflow.yaml +96 -0
  12. configs/shortcut_reflow_test.yaml +96 -0
  13. configs/shortcut_rvqvae_128.yaml +96 -0
  14. configs/shortcut_rvqvae_128_hf.yaml +96 -0
  15. dataloaders/__pycache__/beat_sep_single.cpython-312.pyc +0 -0
  16. dataloaders/__pycache__/build_vocab.cpython-312.pyc +0 -0
  17. dataloaders/__pycache__/data_tools.cpython-312.pyc +0 -0
  18. dataloaders/beat_dataset_new.py +373 -0
  19. dataloaders/beat_sep.py +772 -0
  20. dataloaders/beat_sep_lower.py +430 -0
  21. dataloaders/beat_sep_single.py +693 -0
  22. dataloaders/beat_smplx2020.py +763 -0
  23. dataloaders/build_vocab.py +199 -0
  24. dataloaders/data_tools.py +1756 -0
  25. dataloaders/mix_sep.py +301 -0
  26. dataloaders/pymo/Quaternions.py +468 -0
  27. dataloaders/pymo/__init__.py +0 -0
  28. dataloaders/pymo/__pycache__/Quaternions.cpython-312.pyc +0 -0
  29. dataloaders/pymo/__pycache__/__init__.cpython-312.pyc +0 -0
  30. dataloaders/pymo/__pycache__/data.cpython-312.pyc +0 -0
  31. dataloaders/pymo/__pycache__/parsers.cpython-312.pyc +0 -0
  32. dataloaders/pymo/__pycache__/preprocessing.cpython-312.pyc +0 -0
  33. dataloaders/pymo/__pycache__/rotation_tools.cpython-312.pyc +0 -0
  34. dataloaders/pymo/__pycache__/viz_tools.cpython-312.pyc +0 -0
  35. dataloaders/pymo/data.py +53 -0
  36. dataloaders/pymo/features.py +43 -0
  37. dataloaders/pymo/parsers.py +274 -0
  38. dataloaders/pymo/preprocessing.py +726 -0
  39. dataloaders/pymo/rotation_tools.py +153 -0
  40. dataloaders/pymo/rotation_tools.py! +69 -0
  41. dataloaders/pymo/viz_tools.py +236 -0
  42. dataloaders/pymo/writers.py +55 -0
  43. dataloaders/utils/__pycache__/audio_features.cpython-312.pyc +0 -0
  44. dataloaders/utils/__pycache__/other_tools.cpython-312.pyc +0 -0
  45. dataloaders/utils/__pycache__/rotation_conversions.cpython-312.pyc +0 -0
  46. dataloaders/utils/audio_features.py +80 -0
  47. dataloaders/utils/data_sample.py +175 -0
  48. dataloaders/utils/mis_features.py +64 -0
  49. dataloaders/utils/motion_rep_transfer.py +236 -0
  50. dataloaders/utils/other_tools.py +748 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo/examples/2_scott_0_1_1.wav filter=lfs diff=lfs merge=lfs -text
37
+ demo/examples/2_scott_0_2_2.wav filter=lfs diff=lfs merge=lfs -text
38
+ demo/examples/2_scott_0_3_3.wav filter=lfs diff=lfs merge=lfs -text
39
+ demo/examples/2_scott_0_4_4.wav filter=lfs diff=lfs merge=lfs -text
40
+ demo/examples/2_scott_0_5_5.wav filter=lfs diff=lfs merge=lfs -text
configs/beat2_rvqvae.yaml ADDED
@@ -0,0 +1,134 @@
1
+ is_train: True
2
+ ddp: False
3
+ stat: ts
4
+ root_path: ./
5
+ out_path: ./outputs/audio2pose/
6
+ project: s2g
7
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
8
+ e_path: weights/AESKConv_240_100.bin
9
+ eval_model: motion_representation
10
+ e_name: VAESKConv
11
+ test_ckpt: ./outputs/audio2pose/custom/0112_001634_emage/last_500.bin
12
+ data_path_1: ./datasets/hub/
13
+
14
+ vae_test_len: 32
15
+ vae_test_dim: 330
16
+ vae_test_stride: 20
17
+ vae_length: 240
18
+ vae_codebook_size: 256
19
+ vae_layer: 4
20
+ vae_grow: [1,1,2,1]
21
+ variational: False
22
+
23
+ # data config
24
+ training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] #[2]
25
+ additional_data: False
26
+ cache_path: datasets/beat_cache/beat_smplx_en_emage_2_rvqvae/
27
+ dataset: mix_sep
28
+ new_cache: True
29
+ use_amass: False
30
+ # motion config
31
+ ori_joints: beat_smplx_joints
32
+ tar_joints: beat_smplx_full
33
+ pose_rep: smplxflame_30
34
+ pose_norm: False
35
+ pose_fps: 30
36
+ rot6d: True
37
+ pre_frames: 4
38
+ pose_dims: 330
39
+ pose_length: 64
40
+ stride: 20
41
+ test_length: 64
42
+ motion_f: 256
43
+ m_pre_encoder: null
44
+ m_encoder: null
45
+ m_fix_pre: False
46
+
47
+ # audio config
48
+ audio_rep: onset+amplitude
49
+ audio_sr: 16000
50
+ audio_fps: 16000
51
+ audio_norm: False
52
+ audio_f: 256
53
+ # a_pre_encoder: tcn_camn
54
+ # a_encoder: none
55
+ # a_fix_pre: False
56
+
57
+ # text config
58
+ word_rep: textgrid
59
+ word_index_num: 11195
60
+ word_dims: 300
61
+ freeze_wordembed: False
62
+ word_f: 256
63
+ t_pre_encoder: fasttext
64
+ t_encoder: null
65
+ t_fix_pre: False
66
+
67
+ # facial config
68
+ facial_rep: smplxflame_30
69
+ facial_dims: 100
70
+ facial_norm: False
71
+ facial_f: 0
72
+ f_pre_encoder: null
73
+ f_encoder: null
74
+ f_fix_pre: False
75
+
76
+ # speaker config
77
+ id_rep: onehot
78
+ speaker_f: 0
79
+
80
+ # model config
81
+ batch_size: 80 #80
82
+ # warmup_epochs: 1
83
+ # warmup_lr: 1e-6
84
+ lr_base: 4e-4
85
+ model: motion_representation
86
+ g_name: VQVAEConvZero
87
+ trainer: ae_total
88
+ hidden_size: 768
89
+ n_layer: 1
90
+
91
+ rec_weight: 1
92
+ grad_norm: 0.99
93
+ epochs: 200
94
+ test_period: 20
95
+ ll: 3
96
+ lf: 3
97
+ lu: 3
98
+ lh: 3
99
+ cl: 1
100
+ cf: 0
101
+ cu: 1
102
+ ch: 1
103
+
104
+
105
+
106
+ #below is vavae config, copy from QPGESTURE
107
+ #Codebook Configs
108
+ levels: 1
109
+ downs_t: [3]
110
+ strides_t : [2]
111
+ emb_width : 512
112
+ l_bins : 512
113
+ l_mu : 0.99
114
+ commit : 0.1
115
+ hvqvae_multipliers : [1]
116
+ width: 512
117
+ depth: 3
118
+ m_conv : 1.0
119
+ dilation_growth_rate : 3
120
+ sample_length: 80
121
+ use_bottleneck: True
122
+ joint_channel: 6
123
+ # depth: 3
124
+ # width: 128
125
+ # m_conv: 1.0
126
+ # dilation_growth_rate: 1
127
+ # dilation_cycle: None
128
+ vel: 1 # 1 -> 0
129
+ acc: 1 # 1 -> 0
130
+ vqvae_reverse_decoder_dilation: True
131
+
132
+
133
+ ## below is special for emage
134
+ rec_pos_weight : 1.0
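Note: the config above is a flat key/value file, so it can be read straight into an attribute-style namespace. A minimal sketch of that (assuming PyYAML is available; the helper name below is illustrative and not part of this commit):

import yaml
from types import SimpleNamespace

def load_config(path):
    # yaml.safe_load returns a plain dict of the keys shown above
    with open(path, "r") as f:
        cfg = yaml.safe_load(f)
    return SimpleNamespace(**cfg)   # values become attributes, e.g. args.pose_fps

args = load_config("configs/beat2_rvqvae.yaml")
print(args.pose_rep, args.pose_fps, args.vae_codebook_size)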
configs/diffuser_rvqvae_128.yaml ADDED
@@ -0,0 +1,96 @@
1
+ is_train: True
2
+ ddp: False
3
+ stat: ts
4
+ root_path: ./
5
+ out_path: ./outputs/audio2pose/
6
+ project: s2g
7
+ e_path: weights/AESKConv_240_100.bin
8
+ eval_model: motion_representation
9
+ e_name: VAESKConv
10
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
11
+ test_ckpt: ./ckpt/new_540_diffusion.bin
12
+ data_path_1: ./datasets/hub/
13
+ pose_norm: True
14
+ cfg: configs/model_config.yaml
15
+
16
+
17
+ mean_pose_path: ./mean_std/beatx_2_330_mean.npy
18
+ std_pose_path: ./mean_std/beatx_2_330_std.npy
19
+
20
+ mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
21
+ std_trans_path: ./mean_std/beatx_2_trans_std.npy
22
+
23
+
24
+ vqvae_upper_path: ./ckpt/net_300000_upper.pth
25
+ vqvae_hands_path: ./ckpt/net_300000_hands.pth
26
+ vqvae_lower_path: ./ckpt/net_300000_lower.pth
27
+
28
+ vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
29
+ use_trans: True
30
+
31
+ decay_epoch: 500
32
+
33
+ vqvae_squeeze_scale: 4
34
+ vqvae_latent_scale: 5
35
+
36
+ vae_test_len: 32
37
+ vae_test_dim: 330
38
+ vae_test_stride: 20
39
+ vae_length: 240
40
+ vae_codebook_size: 256
41
+ vae_layer: 4
42
+ vae_grow: [1,1,2,1]
43
+ variational: False
44
+
45
+ # data config
46
+ training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
47
+ additional_data: False
48
+ cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
49
+ dataset: beat_sep_lower
50
+ new_cache: False
51
+
52
+ # motion config
53
+ ori_joints: beat_smplx_joints
54
+ tar_joints: beat_smplx_full
55
+ pose_rep: smplxflame_30
56
+ pose_fps: 30
57
+ rot6d: True
58
+ pre_frames: 4
59
+ pose_dims: 330
60
+ pose_length: 128
61
+ stride: 20
62
+ test_length: 128
63
+ m_fix_pre: False
64
+
65
+
66
+ audio_rep: onset+amplitude
67
+ audio_sr: 16000
68
+ audio_fps: 16000
69
+ audio_norm: False
70
+ audio_f: 256
71
+ audio_raw: None
72
+
73
+
74
+ word_rep: textgrid
75
+ word_dims: 300
76
+ t_pre_encoder: fasttext
77
+
78
+
79
+ facial_rep: smplxflame_30
80
+ facial_dims: 100
81
+ facial_norm: False
82
+ facial_f: 0
83
+
84
+
85
+ id_rep: onehot
86
+ speaker_f: 0
87
+
88
+
89
+ batch_size: 128
90
+ lr_base: 2e-4
91
+ trainer: diffuser_rvqvae
92
+
93
+ rec_weight: 1
94
+ grad_norm: 0.99
95
+ epochs: 1000
96
+ test_period: 20
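Note: with pose_norm: True, the mean/std .npy files referenced above are presumably used to standardize the 330-dim pose vectors before training and to undo the scaling at inference. A rough sketch of that convention (the actual call sites live in the trainer, which is not among the files shown; shapes are assumed to match pose_dims):

import numpy as np

mean_pose = np.load("./mean_std/beatx_2_330_mean.npy")  # assumed shape (330,)
std_pose = np.load("./mean_std/beatx_2_330_std.npy")    # assumed shape (330,)

def normalize(pose):
    # pose: (T, 330) array of rotation features
    return (pose - mean_pose) / (std_pose + 1e-8)

def denormalize(pose_norm):
    return pose_norm * (std_pose + 1e-8) + mean_pose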
configs/model_config.yaml ADDED
@@ -0,0 +1,71 @@
1
+ model:
2
+ model_name: GestureDiffuse
3
+ g_name: GestureDiffusion
4
+ do_classifier_free_guidance: False
5
+ guidance_scale: 1.5
6
+
7
+ denoiser:
8
+ target: models.denoiser.GestureDenoiser
9
+ params:
10
+ input_dim: 128
11
+ latent_dim: 256
12
+ ff_size: 1024
13
+ num_layers: 8
14
+ num_heads: 4
15
+ dropout: 0.1
16
+ activation: "gelu"
17
+ n_seed: 8
18
+ flip_sin_to_cos: True
19
+ freq_shift: 0.0
20
+
21
+
22
+
23
+ modality_encoder:
24
+ target: models.modality_encoder.ModalityEncoder
25
+ params:
26
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
27
+ t_fix_pre: False
28
+ audio_dim: 256
29
+ audio_in: 2
30
+ raw_audio: False
31
+ latent_dim: 256
32
+ audio_fps: 30
33
+
34
+
35
+ scheduler:
36
+ target: diffusers.DDIMScheduler
37
+ num_inference_steps: 20
38
+ eta: 0.0
39
+ params:
40
+ num_train_timesteps: 1000
41
+ # with 'linear' or 'scaled_linear', beta_start and beta_end matter; with the cosine schedule ('squaredcos_cap_v2') they are ignored
42
+ beta_start: 0.00085
43
+ beta_end: 0.012
44
+ # 'linear' or 'squaredcos_cap_v2' or 'scaled_linear'
45
+ beta_schedule: 'squaredcos_cap_v2'
46
+ prediction_type: 'sample'
47
+ clip_sample: false
48
+ # 'leading' or 'trailing' or 'linspace'
49
+ timestep_spacing: 'leading'
50
+ # below are for ddim
51
+ set_alpha_to_one: True
52
+ steps_offset: 0
53
+
54
+
55
+ # use ddpm scheduler
56
+ # scheduler:
57
+ # target: diffusers.DDPMScheduler
58
+ # num_inference_steps: 50
59
+ # eta: 0.0
60
+ # params:
61
+ # num_train_timesteps: 1000
62
+ # beta_start: 0.00085
63
+ # beta_end: 0.012
64
+ # beta_schedule: 'squaredcos_cap_v2' # 'squaredcos_cap_v2'
65
+ # prediction_type: 'sample'
66
+ # clip_sample: false
67
+ # variance_type: 'fixed_small_log'
68
+ # # below are for ddim
69
+ # # set_alpha_to_one: True
70
+ # # steps_offset: 1
71
+
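Note: the scheduler block above maps onto the constructor of diffusers.DDIMScheduler; num_inference_steps and eta are consumed at sampling time rather than at construction. A minimal sketch of how such a block is typically instantiated (assuming the diffusers package is installed; the wiring is illustrative, not the repo's own loader):

from diffusers import DDIMScheduler

scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="squaredcos_cap_v2",  # beta_start/beta_end are ignored for this schedule
    prediction_type="sample",
    clip_sample=False,
    timestep_spacing="leading",
    set_alpha_to_one=True,
    steps_offset=0,
)
scheduler.set_timesteps(20)  # num_inference_steps from the config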
configs/sc_model_config.yaml ADDED
@@ -0,0 +1,37 @@
1
+ model:
2
+ model_name: LSM
3
+ g_name: GestureLSM
4
+ do_classifier_free_guidance: False
5
+ guidance_scale: 2
6
+ n_steps: 20
7
+ use_exp: False
8
+
9
+ denoiser:
10
+ target: models.denoiser.GestureDenoiser
11
+ params:
12
+ input_dim: 128
13
+ latent_dim: 256
14
+ ff_size: 1024
15
+ num_layers: 8
16
+ num_heads: 4
17
+ dropout: 0.1
18
+ activation: "gelu"
19
+ n_seed: 8
20
+ flip_sin_to_cos: True
21
+ freq_shift: 0.0
22
+ cond_proj_dim: 256
23
+ use_exp: ${model.use_exp}
24
+
25
+
26
+ modality_encoder:
27
+ target: models.modality_encoder.ModalityEncoder
28
+ params:
29
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
30
+ t_fix_pre: False
31
+ audio_dim: 256
32
+ audio_in: 2
33
+ raw_audio: False
34
+ latent_dim: 256
35
+ audio_fps: 30
36
+ use_exp: ${model.use_exp}
37
+
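Note: the ${model.use_exp} entries are OmegaConf-style interpolations, so both sub-configs inherit the top-level flag. A minimal sketch of how they resolve (whether this repo actually reads the file with OmegaConf is an assumption):

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/sc_model_config.yaml")
print(cfg.model.denoiser.params.use_exp)          # False, copied from model.use_exp
print(cfg.model.modality_encoder.params.use_exp)  # False as well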
configs/sc_model_holistic_config.yaml ADDED
@@ -0,0 +1,37 @@
1
+ model:
2
+ model_name: LSM
3
+ g_name: GestureLSM
4
+ do_classifier_free_guidance: False
5
+ guidance_scale: 2
6
+ n_steps: 25
7
+ use_exp: True
8
+
9
+ denoiser:
10
+ target: models.denoiser.GestureDenoiser
11
+ params:
12
+ input_dim: 128
13
+ latent_dim: 256
14
+ ff_size: 1024
15
+ num_layers: 8
16
+ num_heads: 4
17
+ dropout: 0.1
18
+ activation: "gelu"
19
+ n_seed: 8
20
+ flip_sin_to_cos: True
21
+ freq_shift: 0.0
22
+ cond_proj_dim: 256
23
+ use_exp: ${model.use_exp}
24
+
25
+
26
+ modality_encoder:
27
+ target: models.modality_encoder.ModalityEncoder
28
+ params:
29
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
30
+ t_fix_pre: False
31
+ audio_dim: 256
32
+ audio_in: 2
33
+ raw_audio: False
34
+ latent_dim: 256
35
+ audio_fps: 30
36
+ use_exp: ${model.use_exp}
37
+
configs/sc_reflow_model_config.yaml ADDED
@@ -0,0 +1,37 @@
1
+ model:
2
+ model_name: LSM
3
+ g_name: GestureLSM
4
+ do_classifier_free_guidance: False
5
+ guidance_scale: 2
6
+ n_steps: 2
7
+ use_exp: False
8
+
9
+ denoiser:
10
+ target: models.denoiser.GestureDenoiser
11
+ params:
12
+ input_dim: 128
13
+ latent_dim: 256
14
+ ff_size: 1024
15
+ num_layers: 8
16
+ num_heads: 4
17
+ dropout: 0.1
18
+ activation: "gelu"
19
+ n_seed: 8
20
+ flip_sin_to_cos: True
21
+ freq_shift: 0.0
22
+ cond_proj_dim: 256
23
+ use_exp: ${model.use_exp}
24
+
25
+
26
+ modality_encoder:
27
+ target: models.modality_encoder.ModalityEncoder
28
+ params:
29
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
30
+ t_fix_pre: False
31
+ audio_dim: 256
32
+ audio_in: 2
33
+ raw_audio: False
34
+ latent_dim: 256
35
+ audio_fps: 30
36
+ use_exp: ${model.use_exp}
37
+
configs/shortcut.yaml ADDED
@@ -0,0 +1,96 @@
1
+ is_train: True
2
+ ddp: False
3
+ stat: ts
4
+ root_path: ./
5
+ out_path: ./outputs/audio2pose/
6
+ project: s2g
7
+ e_path: weights/AESKConv_240_100.bin
8
+ eval_model: motion_representation
9
+ e_name: VAESKConv
10
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
11
+ test_ckpt: ./ckpt/new_540_shortcut.bin
12
+ data_path_1: ./datasets/hub/
13
+ pose_norm: True
14
+ cfg: configs/sc_model_config.yaml
15
+
16
+
17
+ mean_pose_path: ./mean_std/beatx_2_330_mean.npy
18
+ std_pose_path: ./mean_std/beatx_2_330_std.npy
19
+
20
+ mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
21
+ std_trans_path: ./mean_std/beatx_2_trans_std.npy
22
+
23
+
24
+ vqvae_upper_path: ./ckpt/net_300000_upper.pth
25
+ vqvae_hands_path: ./ckpt/net_300000_hands.pth
26
+ vqvae_lower_path: ./ckpt/net_300000_lower.pth
27
+ vqvae_face_path: ./ckpt/net_300000_face.pth
28
+ vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
29
+ use_trans: True
30
+
31
+ decay_epoch: 500
32
+
33
+ vqvae_squeeze_scale: 4
34
+ vqvae_latent_scale: 5
35
+
36
+ vae_test_len: 32
37
+ vae_test_dim: 330
38
+ vae_test_stride: 20
39
+ vae_length: 240
40
+ vae_codebook_size: 256
41
+ vae_layer: 4
42
+ vae_grow: [1,1,2,1]
43
+ variational: False
44
+
45
+ # data config
46
+ training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
47
+ additional_data: False
48
+ cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
49
+ dataset: beat_sep_lower
50
+ new_cache: False
51
+
52
+ # motion config
53
+ ori_joints: beat_smplx_joints
54
+ tar_joints: beat_smplx_full
55
+ pose_rep: smplxflame_30
56
+ pose_fps: 30
57
+ rot6d: True
58
+ pre_frames: 4
59
+ pose_dims: 330
60
+ pose_length: 128
61
+ stride: 20
62
+ test_length: 128
63
+ m_fix_pre: False
64
+
65
+
66
+ audio_rep: onset+amplitude
67
+ audio_sr: 16000
68
+ audio_fps: 16000
69
+ audio_norm: False
70
+ audio_f: 256
71
+ audio_raw: None
72
+
73
+
74
+ word_rep: textgrid
75
+ word_dims: 300
76
+ t_pre_encoder: fasttext
77
+
78
+
79
+ facial_rep: smplxflame_30
80
+ facial_dims: 100
81
+ facial_norm: False
82
+ facial_f: 0
83
+
84
+
85
+ id_rep: onehot
86
+ speaker_f: 0
87
+
88
+
89
+ batch_size: 128
90
+ lr_base: 2e-4
91
+ trainer: shortcut_rvqvae
92
+
93
+ rec_weight: 1
94
+ grad_norm: 0.99
95
+ epochs: 1000
96
+ test_period: 20
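Note: this training config points at configs/sc_model_config.yaml via cfg:, whose target:/params: entries follow the common instantiate-from-string pattern. A sketch of that pattern (the helper name is illustrative; the repo's actual builder is not among the files shown):

import importlib

def instantiate_from_config(node):
    # node is e.g. {"target": "models.denoiser.GestureDenoiser", "params": {...}}
    module_name, cls_name = node["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**node.get("params", {}))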
configs/shortcut_hf.yaml ADDED
@@ -0,0 +1,96 @@
1
+ is_train: True
2
+ ddp: False
3
+ stat: ts
4
+ root_path: ./
5
+ out_path: ./outputs/audio2pose/
6
+ project: s2g
7
+ e_path: weights/AESKConv_240_100.bin
8
+ eval_model: motion_representation
9
+ e_name: VAESKConv
10
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
11
+ test_ckpt: ./ckpt/new_540_shortcut.bin
12
+ data_path_1: ./datasets/hub/
13
+ pose_norm: True
14
+ cfg: configs/sc_model_config.yaml
15
+
16
+
17
+ mean_pose_path: ./mean_std/beatx_2_330_mean.npy
18
+ std_pose_path: ./mean_std/beatx_2_330_std.npy
19
+
20
+ mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
21
+ std_trans_path: ./mean_std/beatx_2_trans_std.npy
22
+
23
+
24
+ vqvae_upper_path: ./ckpt/net_300000_upper.pth
25
+ vqvae_hands_path: ./ckpt/net_300000_hands.pth
26
+ vqvae_lower_path: ./ckpt/net_300000_lower.pth
27
+
28
+ vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
29
+ use_trans: True
30
+
31
+ decay_epoch: 500
32
+
33
+ vqvae_squeeze_scale: 4
34
+ vqvae_latent_scale: 5
35
+
36
+ vae_test_len: 32
37
+ vae_test_dim: 330
38
+ vae_test_stride: 20
39
+ vae_length: 240
40
+ vae_codebook_size: 256
41
+ vae_layer: 4
42
+ vae_grow: [1,1,2,1]
43
+ variational: False
44
+
45
+ # data config
46
+ training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
47
+ additional_data: False
48
+ cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
49
+ dataset: beat_sep_single
50
+ new_cache: False
51
+
52
+ # motion config
53
+ ori_joints: beat_smplx_joints
54
+ tar_joints: beat_smplx_full
55
+ pose_rep: smplxflame_30
56
+ pose_fps: 30
57
+ rot6d: True
58
+ pre_frames: 4
59
+ pose_dims: 330
60
+ pose_length: 128
61
+ stride: 20
62
+ test_length: 128
63
+ m_fix_pre: False
64
+
65
+
66
+ audio_rep: onset+amplitude
67
+ audio_sr: 16000
68
+ audio_fps: 16000
69
+ audio_norm: False
70
+ audio_f: 256
71
+ audio_raw: None
72
+
73
+
74
+ word_rep: textgrid
75
+ word_dims: 300
76
+ t_pre_encoder: fasttext
77
+
78
+
79
+ facial_rep: smplxflame_30
80
+ facial_dims: 100
81
+ facial_norm: False
82
+ facial_f: 0
83
+
84
+
85
+ id_rep: onehot
86
+ speaker_f: 0
87
+
88
+
89
+ batch_size: 128
90
+ lr_base: 2e-4
91
+ trainer: shortcut_rvqvae
92
+
93
+ rec_weight: 1
94
+ grad_norm: 0.99
95
+ epochs: 1000
96
+ test_period: 20
configs/shortcut_holistic.yaml ADDED
@@ -0,0 +1,96 @@
1
+ is_train: True
2
+ ddp: False
3
+ stat: ts
4
+ root_path: ./
5
+ out_path: ./outputs/audio2pose/
6
+ project: s2g
7
+ e_path: weights/AESKConv_240_100.bin
8
+ eval_model: motion_representation
9
+ e_name: VAESKConv
10
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
11
+ test_ckpt: ./ckpt/new_540_shortcut_holistic.bin
12
+ data_path_1: ./datasets/hub/
13
+ pose_norm: True
14
+ cfg: configs/sc_model_holistic_config.yaml
15
+
16
+
17
+ mean_pose_path: ./mean_std/beatx_2_330_mean.npy
18
+ std_pose_path: ./mean_std/beatx_2_330_std.npy
19
+
20
+ mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
21
+ std_trans_path: ./mean_std/beatx_2_trans_std.npy
22
+
23
+
24
+ vqvae_upper_path: ./ckpt/net_300000_upper.pth
25
+ vqvae_hands_path: ./ckpt/net_300000_hands.pth
26
+ vqvae_lower_path: ./ckpt/net_300000_lower.pth
27
+ vqvae_face_path: ./ckpt/net_300000_face.pth
28
+ vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
29
+ use_trans: True
30
+
31
+ decay_epoch: 500
32
+
33
+ vqvae_squeeze_scale: 4
34
+ vqvae_latent_scale: 5
35
+
36
+ vae_test_len: 32
37
+ vae_test_dim: 330
38
+ vae_test_stride: 20
39
+ vae_length: 240
40
+ vae_codebook_size: 256
41
+ vae_layer: 4
42
+ vae_grow: [1,1,2,1]
43
+ variational: False
44
+
45
+ # data config
46
+ training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
47
+ additional_data: False
48
+ cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
49
+ dataset: beat_sep_lower
50
+ new_cache: False
51
+
52
+ # motion config
53
+ ori_joints: beat_smplx_joints
54
+ tar_joints: beat_smplx_full
55
+ pose_rep: smplxflame_30
56
+ pose_fps: 30
57
+ rot6d: True
58
+ pre_frames: 4
59
+ pose_dims: 330
60
+ pose_length: 128
61
+ stride: 20
62
+ test_length: 128
63
+ m_fix_pre: False
64
+
65
+
66
+ audio_rep: onset+amplitude
67
+ audio_sr: 16000
68
+ audio_fps: 16000
69
+ audio_norm: False
70
+ audio_f: 256
71
+ audio_raw: None
72
+
73
+
74
+ word_rep: textgrid
75
+ word_dims: 300
76
+ t_pre_encoder: fasttext
77
+
78
+
79
+ facial_rep: smplxflame_30
80
+ facial_dims: 100
81
+ facial_norm: False
82
+ facial_f: 0
83
+
84
+
85
+ id_rep: onehot
86
+ speaker_f: 0
87
+
88
+
89
+ batch_size: 128
90
+ lr_base: 2e-4
91
+ trainer: shortcut_rvqvae
92
+
93
+ rec_weight: 1
94
+ grad_norm: 0.99
95
+ epochs: 1000
96
+ test_period: 20
configs/shortcut_reflow.yaml ADDED
@@ -0,0 +1,96 @@
1
+ is_train: True
2
+ ddp: False
3
+ stat: ts
4
+ root_path: ./
5
+ out_path: ./outputs/audio2pose/
6
+ project: s2g
7
+ e_path: weights/AESKConv_240_100.bin
8
+ eval_model: motion_representation
9
+ e_name: VAESKConv
10
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
11
+ test_ckpt: ./outputs/audio2pose/custom/0212_125039_shortcut_reflow/last_20.bin
12
+ data_path_1: ./datasets/hub/
13
+ pose_norm: True
14
+ cfg: configs/sc_model_config.yaml
15
+
16
+
17
+ mean_pose_path: ./mean_std/beatx_2_330_mean.npy
18
+ std_pose_path: ./mean_std/beatx_2_330_std.npy
19
+
20
+ mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
21
+ std_trans_path: ./mean_std/beatx_2_trans_std.npy
22
+
23
+
24
+ vqvae_upper_path: ./ckpt/net_300000_upper.pth
25
+ vqvae_hands_path: ./ckpt/net_300000_hands.pth
26
+ vqvae_lower_path: ./ckpt/net_300000_lower.pth
27
+ vqvae_face_path: ./ckpt/net_300000_face.pth
28
+ vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
29
+ use_trans: True
30
+
31
+ decay_epoch: 500
32
+
33
+ vqvae_squeeze_scale: 4
34
+ vqvae_latent_scale: 5
35
+
36
+ vae_test_len: 32
37
+ vae_test_dim: 330
38
+ vae_test_stride: 20
39
+ vae_length: 240
40
+ vae_codebook_size: 256
41
+ vae_layer: 4
42
+ vae_grow: [1,1,2,1]
43
+ variational: False
44
+
45
+ # data config
46
+ training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
47
+ additional_data: False
48
+ cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
49
+ dataset: beat_sep_reflow
50
+ new_cache: False
51
+
52
+ # motion config
53
+ ori_joints: beat_smplx_joints
54
+ tar_joints: beat_smplx_full
55
+ pose_rep: smplxflame_30
56
+ pose_fps: 30
57
+ rot6d: True
58
+ pre_frames: 4
59
+ pose_dims: 330
60
+ pose_length: 128
61
+ stride: 20
62
+ test_length: 128
63
+ m_fix_pre: False
64
+
65
+
66
+ audio_rep: onset+amplitude
67
+ audio_sr: 16000
68
+ audio_fps: 16000
69
+ audio_norm: False
70
+ audio_f: 256
71
+ audio_raw: None
72
+
73
+
74
+ word_rep: textgrid
75
+ word_dims: 300
76
+ t_pre_encoder: fasttext
77
+
78
+
79
+ facial_rep: smplxflame_30
80
+ facial_dims: 100
81
+ facial_norm: False
82
+ facial_f: 0
83
+
84
+
85
+ id_rep: onehot
86
+ speaker_f: 0
87
+
88
+
89
+ batch_size: 1
90
+ lr_base: 2e-4
91
+ trainer: shortcut_rvqvae
92
+
93
+ rec_weight: 1
94
+ grad_norm: 0.99
95
+ epochs: 1000
96
+ test_period: 20
configs/shortcut_reflow_test.yaml ADDED
@@ -0,0 +1,96 @@
1
+ is_train: True
2
+ ddp: False
3
+ stat: ts
4
+ root_path: ./
5
+ out_path: ./outputs/audio2pose/
6
+ project: s2g
7
+ e_path: weights/AESKConv_240_100.bin
8
+ eval_model: motion_representation
9
+ e_name: VAESKConv
10
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
11
+ test_ckpt: ./ckpt/shortcut_reflow.bin
12
+ data_path_1: ./datasets/hub/
13
+ pose_norm: True
14
+ cfg: configs/sc_reflow_model_config.yaml
15
+
16
+
17
+ mean_pose_path: ./mean_std/beatx_2_330_mean.npy
18
+ std_pose_path: ./mean_std/beatx_2_330_std.npy
19
+
20
+ mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
21
+ std_trans_path: ./mean_std/beatx_2_trans_std.npy
22
+
23
+
24
+ vqvae_upper_path: ./ckpt/net_300000_upper.pth
25
+ vqvae_hands_path: ./ckpt/net_300000_hands.pth
26
+ vqvae_lower_path: ./ckpt/net_300000_lower.pth
27
+ vqvae_face_path: ./ckpt/net_300000_face.pth
28
+ vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
29
+ use_trans: True
30
+
31
+ decay_epoch: 500
32
+
33
+ vqvae_squeeze_scale: 4
34
+ vqvae_latent_scale: 5
35
+
36
+ vae_test_len: 32
37
+ vae_test_dim: 330
38
+ vae_test_stride: 20
39
+ vae_length: 240
40
+ vae_codebook_size: 256
41
+ vae_layer: 4
42
+ vae_grow: [1,1,2,1]
43
+ variational: False
44
+
45
+ # data config
46
+ training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
47
+ additional_data: False
48
+ cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
49
+ dataset: beat_sep_lower
50
+ new_cache: False
51
+
52
+ # motion config
53
+ ori_joints: beat_smplx_joints
54
+ tar_joints: beat_smplx_full
55
+ pose_rep: smplxflame_30
56
+ pose_fps: 30
57
+ rot6d: True
58
+ pre_frames: 4
59
+ pose_dims: 330
60
+ pose_length: 128
61
+ stride: 20
62
+ test_length: 128
63
+ m_fix_pre: False
64
+
65
+
66
+ audio_rep: onset+amplitude
67
+ audio_sr: 16000
68
+ audio_fps: 16000
69
+ audio_norm: False
70
+ audio_f: 256
71
+ audio_raw: None
72
+
73
+
74
+ word_rep: textgrid
75
+ word_dims: 300
76
+ t_pre_encoder: fasttext
77
+
78
+
79
+ facial_rep: smplxflame_30
80
+ facial_dims: 100
81
+ facial_norm: False
82
+ facial_f: 0
83
+
84
+
85
+ id_rep: onehot
86
+ speaker_f: 0
87
+
88
+
89
+ batch_size: 1
90
+ lr_base: 2e-4
91
+ trainer: shortcut_rvqvae
92
+
93
+ rec_weight: 1
94
+ grad_norm: 0.99
95
+ epochs: 1000
96
+ test_period: 20
configs/shortcut_rvqvae_128.yaml ADDED
@@ -0,0 +1,96 @@
1
+ is_train: True
2
+ ddp: False
3
+ stat: ts
4
+ root_path: ./
5
+ out_path: ./outputs/audio2pose/
6
+ project: s2g
7
+ e_path: weights/AESKConv_240_100.bin
8
+ eval_model: motion_representation
9
+ e_name: VAESKConv
10
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
11
+ test_ckpt: ./ckpt/new_540_shortcut.bin
12
+ data_path_1: ./datasets/hub/
13
+ pose_norm: True
14
+ cfg: configs/sc_model_config.yaml
15
+
16
+
17
+ mean_pose_path: ./mean_std/beatx_2_330_mean.npy
18
+ std_pose_path: ./mean_std/beatx_2_330_std.npy
19
+
20
+ mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
21
+ std_trans_path: ./mean_std/beatx_2_trans_std.npy
22
+
23
+
24
+ vqvae_upper_path: ./ckpt/net_300000_upper.pth
25
+ vqvae_hands_path: ./ckpt/net_300000_hands.pth
26
+ vqvae_lower_path: ./ckpt/net_300000_lower.pth
27
+
28
+ vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
29
+ use_trans: True
30
+
31
+ decay_epoch: 500
32
+
33
+ vqvae_squeeze_scale: 4
34
+ vqvae_latent_scale: 5
35
+
36
+ vae_test_len: 32
37
+ vae_test_dim: 330
38
+ vae_test_stride: 20
39
+ vae_length: 240
40
+ vae_codebook_size: 256
41
+ vae_layer: 4
42
+ vae_grow: [1,1,2,1]
43
+ variational: False
44
+
45
+ # data config
46
+ training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
47
+ additional_data: False
48
+ cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
49
+ dataset: beat_sep_lower
50
+ new_cache: False
51
+
52
+ # motion config
53
+ ori_joints: beat_smplx_joints
54
+ tar_joints: beat_smplx_full
55
+ pose_rep: smplxflame_30
56
+ pose_fps: 30
57
+ rot6d: True
58
+ pre_frames: 4
59
+ pose_dims: 330
60
+ pose_length: 128
61
+ stride: 20
62
+ test_length: 128
63
+ m_fix_pre: False
64
+
65
+
66
+ audio_rep: onset+amplitude
67
+ audio_sr: 16000
68
+ audio_fps: 16000
69
+ audio_norm: False
70
+ audio_f: 256
71
+ audio_raw: None
72
+
73
+
74
+ word_rep: textgrid
75
+ word_dims: 300
76
+ t_pre_encoder: fasttext
77
+
78
+
79
+ facial_rep: smplxflame_30
80
+ facial_dims: 100
81
+ facial_norm: False
82
+ facial_f: 0
83
+
84
+
85
+ id_rep: onehot
86
+ speaker_f: 0
87
+
88
+
89
+ batch_size: 128
90
+ lr_base: 2e-4
91
+ trainer: shortcut_rvqvae
92
+
93
+ rec_weight: 1
94
+ grad_norm: 0.99
95
+ epochs: 1000
96
+ test_period: 20
configs/shortcut_rvqvae_128_hf.yaml ADDED
@@ -0,0 +1,96 @@
1
+ is_train: True
2
+ ddp: False
3
+ stat: ts
4
+ root_path: ./
5
+ out_path: ./outputs/audio2pose/
6
+ project: s2g
7
+ e_path: weights/AESKConv_240_100.bin
8
+ eval_model: motion_representation
9
+ e_name: VAESKConv
10
+ data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
11
+ test_ckpt: ./ckpt/new_540_shortcut.bin
12
+ data_path_1: ./datasets/hub/
13
+ pose_norm: True
14
+ cfg: configs/sc_model_config.yaml
15
+
16
+
17
+ mean_pose_path: ./mean_std/beatx_2_330_mean.npy
18
+ std_pose_path: ./mean_std/beatx_2_330_std.npy
19
+
20
+ mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
21
+ std_trans_path: ./mean_std/beatx_2_trans_std.npy
22
+
23
+
24
+ vqvae_upper_path: ./ckpt/net_300000_upper.pth
25
+ vqvae_hands_path: ./ckpt/net_300000_hands.pth
26
+ vqvae_lower_path: ./ckpt/net_300000_lower.pth
27
+
28
+ vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
29
+ use_trans: True
30
+
31
+ decay_epoch: 500
32
+
33
+ vqvae_squeeze_scale: 4
34
+ vqvae_latent_scale: 5
35
+
36
+ vae_test_len: 32
37
+ vae_test_dim: 330
38
+ vae_test_stride: 20
39
+ vae_length: 240
40
+ vae_codebook_size: 256
41
+ vae_layer: 4
42
+ vae_grow: [1,1,2,1]
43
+ variational: False
44
+
45
+ # data config
46
+ training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
47
+ additional_data: False
48
+ cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
49
+ dataset: beat_sep_single
50
+ new_cache: False
51
+
52
+ # motion config
53
+ ori_joints: beat_smplx_joints
54
+ tar_joints: beat_smplx_full
55
+ pose_rep: smplxflame_30
56
+ pose_fps: 30
57
+ rot6d: True
58
+ pre_frames: 4
59
+ pose_dims: 330
60
+ pose_length: 128
61
+ stride: 20
62
+ test_length: 128
63
+ m_fix_pre: False
64
+
65
+
66
+ audio_rep: onset+amplitude
67
+ audio_sr: 16000
68
+ audio_fps: 16000
69
+ audio_norm: False
70
+ audio_f: 256
71
+ audio_raw: None
72
+
73
+
74
+ word_rep: textgrid
75
+ word_dims: 300
76
+ t_pre_encoder: fasttext
77
+
78
+
79
+ facial_rep: smplxflame_30
80
+ facial_dims: 100
81
+ facial_norm: False
82
+ facial_f: 0
83
+
84
+
85
+ id_rep: onehot
86
+ speaker_f: 0
87
+
88
+
89
+ batch_size: 128
90
+ lr_base: 2e-4
91
+ trainer: shortcut_rvqvae
92
+
93
+ rec_weight: 1
94
+ grad_norm: 0.99
95
+ epochs: 1000
96
+ test_period: 20
dataloaders/__pycache__/beat_sep_single.cpython-312.pyc ADDED
Binary file (42 kB).
 
dataloaders/__pycache__/build_vocab.cpython-312.pyc ADDED
Binary file (10.3 kB).
 
dataloaders/__pycache__/data_tools.cpython-312.pyc ADDED
Binary file (43.4 kB).
 
dataloaders/beat_dataset_new.py ADDED
@@ -0,0 +1,373 @@
1
+ import os
2
+ import pickle
3
+ import math
4
+ import shutil
5
+ import numpy as np
6
+ import lmdb as lmdb
7
+ import textgrid as tg
8
+ import pandas as pd
9
+ import torch
10
+ import glob
11
+ import json
12
+ from termcolor import colored
13
+ from loguru import logger
14
+ from collections import defaultdict
15
+ from torch.utils.data import Dataset
16
+ import torch.distributed as dist
17
+ import pickle
18
+ import smplx
19
+ from .utils.audio_features import AudioProcessor
20
+ from .utils.other_tools import MultiLMDBManager
21
+ from .utils.motion_rep_transfer import process_smplx_motion
22
+ from .utils.mis_features import process_semantic_data, process_emotion_data
23
+ from .utils.text_features import process_word_data
24
+ from .utils.data_sample import sample_from_clip
25
+ from .utils import rotation_conversions as rc
26
+
27
+ class CustomDataset(Dataset):
28
+ def __init__(self, args, loader_type, build_cache=True):
29
+ self.args = args
30
+ self.loader_type = loader_type
31
+ self.rank = dist.get_rank()
32
+
33
+ self.ori_stride = self.args.stride
34
+ self.ori_length = self.args.pose_length
35
+
36
+ # Initialize basic parameters
37
+ self.ori_stride = self.args.stride
38
+ self.ori_length = self.args.pose_length
39
+ self.alignment = [0,0] # for trinity
40
+
41
+ # Initialize SMPLX model
42
+ self.smplx = smplx.create(
43
+ self.args.data_path_1+"smplx_models/",
44
+ model_type='smplx',
45
+ gender='NEUTRAL_2020',
46
+ use_face_contour=False,
47
+ num_betas=300,
48
+ num_expression_coeffs=100,
49
+ ext='npz',
50
+ use_pca=False,
51
+ ).cuda().eval()
52
+
53
+ self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
54
+
55
+ # Load and process split rules
56
+ self._process_split_rules()
57
+
58
+ # Initialize data directories and lengths
59
+ self._init_data_paths()
60
+
61
+ # Build or load cache
62
+ self._init_cache(build_cache)
63
+
64
+
65
+ def _process_split_rules(self):
66
+ """Process dataset split rules."""
67
+ split_rule = pd.read_csv(self.args.data_path+"train_test_split.csv")
68
+ self.selected_file = split_rule.loc[
69
+ (split_rule['type'] == self.loader_type) &
70
+ (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
71
+ ]
72
+
73
+ if self.args.additional_data and self.loader_type == 'train':
74
+ split_b = split_rule.loc[
75
+ (split_rule['type'] == 'additional') &
76
+ (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
77
+ ]
78
+ self.selected_file = pd.concat([self.selected_file, split_b])
79
+
80
+ if self.selected_file.empty:
81
+ logger.warning(f"{self.loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead")
82
+ self.selected_file = split_rule.loc[
83
+ (split_rule['type'] == 'train') &
84
+ (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
85
+ ]
86
+ self.selected_file = self.selected_file.iloc[0:8]
87
+
88
+ def _init_data_paths(self):
89
+ """Initialize data directories and lengths."""
90
+ self.data_dir = self.args.data_path
91
+
92
+ if self.loader_type == "test":
93
+ self.args.multi_length_training = [1.0]
94
+
95
+ self.max_length = int(self.args.pose_length * self.args.multi_length_training[-1])
96
+ self.max_audio_pre_len = math.floor(self.args.pose_length / self.args.pose_fps * self.args.audio_sr)
97
+
98
+ if self.max_audio_pre_len > self.args.test_length * self.args.audio_sr:
99
+ self.max_audio_pre_len = self.args.test_length * self.args.audio_sr
100
+
101
+ if self.args.test_clip and self.loader_type == "test":
102
+ self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + "_clip" + f"/{self.args.pose_rep}_cache"
103
+ else:
104
+ self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + f"/{self.args.pose_rep}_cache"
105
+
106
+
107
+
108
+ def _init_cache(self, build_cache):
109
+ """Initialize or build cache."""
110
+ self.lmdb_envs = {}
111
+ self.mapping_data = None
112
+
113
+ if build_cache and self.rank == 0:
114
+ self.build_cache(self.preloaded_dir)
115
+
116
+ self.load_db_mapping()
117
+
118
+ def build_cache(self, preloaded_dir):
119
+ """Build the dataset cache."""
120
+ logger.info(f"Audio bit rate: {self.args.audio_fps}")
121
+ logger.info("Reading data '{}'...".format(self.data_dir))
122
+ logger.info("Creating the dataset cache...")
123
+
124
+ if self.args.new_cache and os.path.exists(preloaded_dir):
125
+ shutil.rmtree(preloaded_dir)
126
+
127
+ if os.path.exists(preloaded_dir):
128
+ # if the dir is empty, that means we still need to build the cache
129
+ if not os.listdir(preloaded_dir):
130
+ self.cache_generation(
131
+ preloaded_dir,
132
+ self.args.disable_filtering,
133
+ self.args.clean_first_seconds,
134
+ self.args.clean_final_seconds,
135
+ is_test=False
136
+ )
137
+ else:
138
+ logger.info("Found the cache {}".format(preloaded_dir))
139
+
140
+ elif self.loader_type == "test":
141
+ self.cache_generation(preloaded_dir, True, 0, 0, is_test=True)
142
+ else:
143
+ self.cache_generation(
144
+ preloaded_dir,
145
+ self.args.disable_filtering,
146
+ self.args.clean_first_seconds,
147
+ self.args.clean_final_seconds,
148
+ is_test=False
149
+ )
150
+
151
+ def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
152
+ """Generate cache for the dataset."""
153
+ if not os.path.exists(out_lmdb_dir):
154
+ os.makedirs(out_lmdb_dir)
155
+
156
+ self.audio_processor = AudioProcessor(layer=self.args.n_layer, use_distill=self.args.use_distill)
157
+
158
+ # Initialize the multi-LMDB manager
159
+ lmdb_manager = MultiLMDBManager(out_lmdb_dir, max_db_size=10*1024*1024*1024)
160
+
161
+ self.n_out_samples = 0
162
+ n_filtered_out = defaultdict(int)
163
+
164
+ for index, file_name in self.selected_file.iterrows():
165
+ f_name = file_name["id"]
166
+ ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
167
+ pose_file = os.path.join(self.data_dir, self.args.pose_rep, f_name + ext)
168
+
169
+ # Process data
170
+ data = self._process_file_data(f_name, pose_file, ext)
171
+ if data is None:
172
+ continue
173
+
174
+ # Sample from clip
175
+ filtered_result, self.n_out_samples = sample_from_clip(
176
+ lmdb_manager=lmdb_manager,
177
+ audio_file=pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav"),
178
+ audio_each_file=data['audio_tensor'],
179
+ high_each_file=data['high_level'],
180
+ low_each_file=data['low_level'],
181
+ pose_each_file=data['pose'],
182
+ rep15d_each_file=data['rep15d'],
183
+ trans_each_file=data['trans'],
184
+ trans_v_each_file=data['trans_v'],
185
+ shape_each_file=data['shape'],
186
+ facial_each_file=data['facial'],
187
+ aligned_text_each_file=data['aligned_text'],
188
+ word_each_file=data['word'] if self.args.word_rep is not None else None,
189
+ vid_each_file=data['vid'],
190
+ emo_each_file=data['emo'],
191
+ sem_each_file=data['sem'],
192
+ intention_each_file=data['intention'] if data['intention'] is not None else None,
193
+ audio_onset_each_file=data['audio_onset'] if self.args.onset_rep else None,
194
+ args=self.args,
195
+ ori_stride=self.ori_stride,
196
+ ori_length=self.ori_length,
197
+ disable_filtering=disable_filtering,
198
+ clean_first_seconds=clean_first_seconds,
199
+ clean_final_seconds=clean_final_seconds,
200
+ is_test=is_test,
201
+ n_out_samples=self.n_out_samples
202
+ )
203
+
204
+ for type_key in filtered_result:
205
+ n_filtered_out[type_key] += filtered_result[type_key]
206
+
207
+ lmdb_manager.close()
208
+
209
+ def _process_file_data(self, f_name, pose_file, ext):
210
+ """Process all data for a single file."""
211
+ data = {
212
+ 'pose': None, 'trans': None, 'trans_v': None, 'shape': None,
213
+ 'audio': None, 'facial': None, 'word': None, 'emo': None,
214
+ 'sem': None, 'vid': None
215
+ }
216
+
217
+ # Process motion data
218
+ logger.info(colored(f"# ---- Building cache for Pose {f_name} ---- #", "blue"))
219
+ if "smplx" in self.args.pose_rep:
220
+ motion_data = process_smplx_motion(pose_file, self.smplx, self.args.pose_fps, self.args.facial_rep)
221
+ else:
222
+ raise ValueError(f"Unknown pose representation '{self.args.pose_rep}'.")
223
+
224
+ if motion_data is None:
225
+ return None
226
+
227
+ data.update(motion_data)
228
+
229
+ # Process speaker ID
230
+ if self.args.id_rep is not None:
231
+ speaker_id = int(f_name.split("_")[0]) - 1
232
+ data['vid'] = np.repeat(np.array(speaker_id).reshape(1, 1), data['pose'].shape[0], axis=0)
233
+ else:
234
+ data['vid'] = np.array([-1])
235
+
236
+ # Process audio if needed
237
+ if self.args.audio_rep is not None:
238
+ audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
239
+ audio_data = self.audio_processor.get_wav2vec_from_16k_wav(audio_file, aligned_text=True)
240
+ if audio_data is None:
241
+ return None
242
+ data.update(audio_data)
243
+
244
+ if getattr(self.args, "onset_rep", False):
245
+ audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
246
+ onset_data = self.audio_processor.calculate_onset_amplitude(audio_file, data)
247
+ if onset_data is None:
248
+ return None
249
+ data.update(onset_data)
250
+
251
+ # Process emotion if needed
252
+ if self.args.emo_rep is not None:
253
+ data = process_emotion_data(f_name, data, self.args)
254
+ if data is None:
255
+ return None
256
+
257
+ # Process word data if needed
258
+ if self.args.word_rep is not None:
259
+ word_file = f"{self.data_dir}{self.args.word_rep}/{f_name}.TextGrid"
260
+ data = process_word_data(self.data_dir, word_file, self.args, data, f_name, self.selected_file)
261
+ if data is None:
262
+ return None
263
+
264
+
265
+ # Process semantic data if needed
266
+ if self.args.sem_rep is not None:
267
+ sem_file = f"{self.data_dir}{self.args.sem_rep}/{f_name}.txt"
268
+ data = process_semantic_data(sem_file, self.args, data, f_name)
269
+ if data is None:
270
+ return None
271
+
272
+ return data
273
+
274
+ def load_db_mapping(self):
275
+ """Load database mapping from file."""
276
+ mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl")
277
+ with open(mapping_path, 'rb') as f:
278
+ self.mapping_data = pickle.load(f)
279
+
280
+
281
+ # Update paths from test to test_clip if needed
282
+ if self.loader_type == "test" and self.args.test_clip:
283
+ updated_paths = []
284
+ for path in self.mapping_data['db_paths']:
285
+ updated_path = path.replace("test/", "test_clip/")
286
+ updated_paths.append(updated_path)
287
+ self.mapping_data['db_paths'] = updated_paths
288
+
289
+ # Re-save the updated mapping_data to the same pickle file
290
+ with open(mapping_path, 'wb') as f:
291
+ pickle.dump(self.mapping_data, f)
292
+
293
+ self.n_samples = len(self.mapping_data['mapping'])
294
+
295
+ def get_lmdb_env(self, db_idx):
296
+ """Get LMDB environment for given database index."""
297
+ if db_idx not in self.lmdb_envs:
298
+ db_path = self.mapping_data['db_paths'][db_idx]
299
+ self.lmdb_envs[db_idx] = lmdb.open(db_path, readonly=True, lock=False)
300
+ return self.lmdb_envs[db_idx]
301
+
302
+ def __len__(self):
303
+ """Return the total number of samples in the dataset."""
304
+ return self.n_samples
305
+
306
+ def __getitem__(self, idx):
307
+ """Get a single sample from the dataset."""
308
+ db_idx = self.mapping_data['mapping'][idx]
309
+ lmdb_env = self.get_lmdb_env(db_idx)
310
+
311
+ with lmdb_env.begin(write=False) as txn:
312
+ key = "{:008d}".format(idx).encode("ascii")
313
+ sample = txn.get(key)
314
+ sample = pickle.loads(sample)
315
+
316
+
317
+ tar_pose, in_audio, in_audio_high, in_audio_low, tar_rep15d, in_facial, in_shape, in_aligned_text, in_word, emo, sem, vid, trans, trans_v, intention, audio_name, audio_onset = sample
318
+
319
+
320
+ # Convert data to tensors with appropriate types
321
+ processed_data = self._convert_to_tensors(
322
+ tar_pose, tar_rep15d, in_audio, in_audio_high, in_audio_low, in_facial, in_shape, in_aligned_text, in_word,
323
+ emo, sem, vid, trans, trans_v, intention, audio_onset
324
+ )
325
+
326
+ processed_data['audio_name'] = audio_name
327
+ return processed_data
328
+
329
+ def _convert_to_tensors(self, tar_pose, tar_rep15d, in_audio, in_audio_high, in_audio_low, in_facial, in_shape, in_aligned_text, in_word,
330
+ emo, sem, vid, trans, trans_v, intention=None, audio_onset=None):
331
+ """Convert numpy arrays to tensors with appropriate types."""
332
+ data = {
333
+ 'emo': torch.from_numpy(emo).int(),
334
+ 'sem': torch.from_numpy(sem).float(),
335
+ 'audio_tensor': torch.from_numpy(in_audio).float(),
336
+ 'bert_time_aligned': torch.from_numpy(in_aligned_text).float()
337
+ }
338
+ tar_pose = torch.from_numpy(tar_pose).float()
339
+
340
+ if self.loader_type == "test":
341
+ data.update({
342
+ 'pose': tar_pose,
343
+ 'rep15d': torch.from_numpy(tar_rep15d).float(),
344
+ 'trans': torch.from_numpy(trans).float(),
345
+ 'trans_v': torch.from_numpy(trans_v).float(),
346
+ 'facial': torch.from_numpy(in_facial).float(),
347
+ 'id': torch.from_numpy(vid).float(),
348
+ 'beta': torch.from_numpy(in_shape).float()
349
+ })
350
+ else:
351
+ data.update({
352
+ 'pose': tar_pose,
353
+ 'rep15d': torch.from_numpy(tar_rep15d).reshape((tar_rep15d.shape[0], -1)).float(),
354
+ 'trans': torch.from_numpy(trans).reshape((trans.shape[0], -1)).float(),
355
+ 'trans_v': torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float(),
356
+ 'facial': torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float(),
357
+ 'id': torch.from_numpy(vid).reshape((vid.shape[0], -1)).float(),
358
+ 'beta': torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
359
+ })
360
+
361
+
362
+ # Handle audio onset
363
+ if audio_onset is not None:
364
+ data['audio_onset'] = torch.from_numpy(audio_onset).float()
365
+ else:
366
+ data['audio_onset'] = torch.tensor([-1])
367
+
368
+ if in_word is not None:
369
+ data['word'] = torch.from_numpy(in_word).int()
370
+ else:
371
+ data['word'] = torch.tensor([-1])
372
+
373
+ return data
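Note: a rough usage sketch for the CustomDataset above. Its __init__ calls dist.get_rank() and builds a CUDA SMPL-X model, so a (possibly single-process) process group and a GPU are required; the argument plumbing is illustrative and `args` must carry the fields the dataset reads (data paths, stride, pose_length, and so on):

import torch.distributed as dist
from torch.utils.data import DataLoader
from dataloaders.beat_dataset_new import CustomDataset

if not dist.is_initialized():
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

# `args` would come from one of the YAML configs above (see the loading sketch earlier)
dataset = CustomDataset(args, loader_type="train", build_cache=True)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
batch = next(iter(loader))
print(batch["pose"].shape, batch["audio_tensor"].shape)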
dataloaders/beat_sep.py ADDED
@@ -0,0 +1,772 @@
1
+ import os
2
+ import pickle
3
+ import math
4
+ import shutil
5
+ import numpy as np
6
+ import lmdb as lmdb
7
+ import textgrid as tg
8
+ import pandas as pd
9
+ import torch
10
+ import glob
11
+ import json
12
+ from termcolor import colored
13
+ from loguru import logger
14
+ from collections import defaultdict
15
+ from torch.utils.data import Dataset
16
+ import torch.distributed as dist
17
+ #import pyarrow
18
+ import pickle
19
+ import librosa
20
+ import smplx
21
+
22
+ from .build_vocab import Vocab
23
+ from .utils.audio_features import Wav2Vec2Model
24
+ from .data_tools import joints_list
25
+ from .utils import rotation_conversions as rc
26
+ from .utils import other_tools
27
+
28
+ class CustomDataset(Dataset):
29
+ def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
30
+ self.args = args
31
+ self.loader_type = loader_type
32
+
33
+ self.rank = dist.get_rank()
34
+ self.ori_stride = self.args.stride
35
+ self.ori_length = self.args.pose_length
36
+ self.alignment = [0,0] # for trinity
37
+
38
+ self.ori_joint_list = joints_list[self.args.ori_joints]
39
+ self.tar_joint_list = joints_list[self.args.tar_joints]
40
+ if 'smplx' in self.args.pose_rep:
41
+ self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3)
42
+ self.joints = len(list(self.tar_joint_list.keys()))
43
+ for joint_name in self.tar_joint_list:
44
+ self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
45
+ else:
46
+ self.joints = len(list(self.ori_joint_list.keys()))+1
47
+ self.joint_mask = np.zeros(self.joints*3)
48
+ for joint_name in self.tar_joint_list:
49
+ if joint_name == "Hips":
50
+ self.joint_mask[3:6] = 1
51
+ else:
52
+ self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
53
+ # select trainable joints
54
+
55
+ split_rule = pd.read_csv(args.data_path+"train_test_split.csv")
56
+ self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
57
+ if args.additional_data and loader_type == 'train':
58
+ split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
59
+ #self.selected_file = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
60
+ self.selected_file = pd.concat([self.selected_file, split_b])
61
+ if self.selected_file.empty:
62
+ logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead")
63
+ self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
64
+ self.selected_file = self.selected_file.iloc[0:8]
65
+ self.data_dir = args.data_path
66
+
67
+ if loader_type == "test":
68
+ self.args.multi_length_training = [1.0]
69
+ self.max_length = int(args.pose_length * self.args.multi_length_training[-1])
70
+ self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr)
71
+ if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr:
72
+ self.max_audio_pre_len = self.args.test_length*self.args.audio_sr
73
+
74
+ if args.word_rep is not None:
75
+ with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f:
76
+ self.lang_model = pickle.load(f)
77
+
78
+ preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache"
79
+ # if args.pose_norm:
80
+ # # careful for rotation vectors
81
+ # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"):
82
+ # self.calculate_mean_pose()
83
+ # self.mean_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy")
84
+ # self.std_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_std.npy")
85
+ # if args.audio_norm:
86
+ # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/bvh_mean.npy"):
87
+ # self.calculate_mean_audio()
88
+ # self.mean_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_mean.npy")
89
+ # self.std_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_std.npy")
90
+ # if args.facial_norm:
91
+ # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"):
92
+ # self.calculate_mean_face()
93
+ # self.mean_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy")
94
+ # self.std_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy")
95
+ if self.args.beat_align:
96
+ if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"):
97
+ self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
98
+ self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
99
+
100
+ if build_cache and self.rank == 0:
101
+ self.build_cache(preloaded_dir)
102
+ self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False)
103
+ with self.lmdb_env.begin() as txn:
104
+ self.n_samples = txn.stat()["entries"]
105
+
106
+
107
+ def calculate_mean_velocity(self, save_path):
108
+ self.smplx = smplx.create(
109
+ self.args.data_path_1+"smplx_models/",
110
+ model_type='smplx',
111
+ gender='NEUTRAL_2020',
112
+ use_face_contour=False,
113
+ num_betas=300,
114
+ num_expression_coeffs=100,
115
+ ext='npz',
116
+ use_pca=False,
117
+ ).cuda().eval()
118
+ dir_p = self.data_dir + self.args.pose_rep + "/"
119
+ all_list = []
120
+ from tqdm import tqdm
121
+ for tar in tqdm(os.listdir(dir_p)):
122
+ if tar.endswith(".npz"):
123
+ m_data = np.load(dir_p+tar, allow_pickle=True)
124
+ betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"]
125
+ n, c = poses.shape[0], poses.shape[1]
126
+ betas = betas.reshape(1, 300)
127
+ betas = np.tile(betas, (n, 1))
128
+ betas = torch.from_numpy(betas).cuda().float()
129
+ poses = torch.from_numpy(poses.reshape(n, c)).cuda().float()
130
+ exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float()
131
+ trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float()
132
+ max_length = 128
133
+ s, r = n//max_length, n%max_length
134
+ #print(n, s, r)
135
+ all_tensor = []
136
+ for i in range(s):
137
+ with torch.no_grad():
138
+ joints = self.smplx(
139
+ betas=betas[i*max_length:(i+1)*max_length],
140
+ transl=trans[i*max_length:(i+1)*max_length],
141
+ expression=exps[i*max_length:(i+1)*max_length],
142
+ jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69],
143
+ global_orient=poses[i*max_length:(i+1)*max_length,:3],
144
+ body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3],
145
+ left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3],
146
+ right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3],
147
+ return_verts=True,
148
+ return_joints=True,
149
+ leye_pose=poses[i*max_length:(i+1)*max_length, 69:72],
150
+ reye_pose=poses[i*max_length:(i+1)*max_length, 72:75],
151
+ )['joints'][:, :55, :].reshape(max_length, 55*3)
152
+ all_tensor.append(joints)
153
+ if r != 0:
154
+ with torch.no_grad():
155
+ joints = self.smplx(
156
+ betas=betas[s*max_length:s*max_length+r],
157
+ transl=trans[s*max_length:s*max_length+r],
158
+ expression=exps[s*max_length:s*max_length+r],
159
+ jaw_pose=poses[s*max_length:s*max_length+r, 66:69],
160
+ global_orient=poses[s*max_length:s*max_length+r,:3],
161
+ body_pose=poses[s*max_length:s*max_length+r,3:21*3+3],
162
+ left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3],
163
+ right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3],
164
+ return_verts=True,
165
+ return_joints=True,
166
+ leye_pose=poses[s*max_length:s*max_length+r, 69:72],
167
+ reye_pose=poses[s*max_length:s*max_length+r, 72:75],
168
+ )['joints'][:, :55, :].reshape(r, 55*3)
169
+ all_tensor.append(joints)
170
+ joints = torch.cat(all_tensor, axis=0)
171
+ joints = joints.permute(1, 0)
172
+ dt = 1/30
173
+ # first steps is forward diff (t+1 - t) / dt
174
+ init_vel = (joints[:, 1:2] - joints[:, :1]) / dt
175
+ # middle steps are second order (t+1 - t-1) / 2dt
176
+ middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt)
177
+ # last step is backward diff (t - t-1) / dt
178
+ final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
179
+ #print(joints.shape, init_vel.shape, middle_vel.shape, final_vel.shape)
180
+ vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1).permute(1, 0).reshape(n, 55, 3)
181
+ #print(vel_seq.shape)
182
+ #.permute(1, 0).reshape(n, 55, 3)
183
+ vel_seq_np = vel_seq.cpu().numpy()
184
+ vel_joints_np = np.linalg.norm(vel_seq_np, axis=2) # n * 55
185
+ all_list.append(vel_joints_np)
186
+ avg_vel = np.mean(np.concatenate(all_list, axis=0),axis=0) # 55
187
+ np.save(save_path, avg_vel)
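The block above estimates per-joint average speed with a standard finite-difference scheme: a forward difference at the first frame, central differences in between, and a backward difference at the last frame. A minimal self-contained sketch of the same scheme on a synthetic trajectory (the toy array and function name are illustrative, not part of the repository):

import numpy as np

def finite_diff_speed(joints, fps=30):
    # joints: (T, J, 3) joint positions; returns (T, J) per-frame speeds
    dt = 1.0 / fps
    vel = np.empty_like(joints)
    vel[0] = (joints[1] - joints[0]) / dt               # forward difference
    vel[1:-1] = (joints[2:] - joints[:-2]) / (2 * dt)   # central difference
    vel[-1] = (joints[-1] - joints[-2]) / dt            # backward difference
    return np.linalg.norm(vel, axis=2)

toy = np.cumsum(np.random.randn(10, 55, 3) * 0.01, axis=0)   # smooth-ish toy motion
print(finite_diff_speed(toy).mean(axis=0).shape)              # (55,), one mean speed per joint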
188
+
189
+
190
+ def build_cache(self, preloaded_dir):
191
+ logger.info(f"Audio bit rate: {self.args.audio_fps}")
192
+ logger.info("Reading data '{}'...".format(self.data_dir))
193
+ logger.info("Creating the dataset cache...")
194
+ if self.args.new_cache:
195
+ if os.path.exists(preloaded_dir):
196
+ shutil.rmtree(preloaded_dir)
197
+ if os.path.exists(preloaded_dir):
198
+ logger.info("Found the cache {}".format(preloaded_dir))
199
+ elif self.loader_type == "test":
200
+ self.cache_generation(
201
+ preloaded_dir, True,
202
+ 0, 0,
203
+ is_test=True)
204
+ else:
205
+ self.cache_generation(
206
+ preloaded_dir, self.args.disable_filtering,
207
+ self.args.clean_first_seconds, self.args.clean_final_seconds,
208
+ is_test=False)
209
+
210
+ def __len__(self):
211
+ return self.n_samples
212
+
213
+
214
+ def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
215
+ # if "wav2vec2" in self.args.audio_rep:
216
+ # self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h")
217
+ # self.wav2vec_model.feature_extractor._freeze_parameters()
218
+ # self.wav2vec_model = self.wav2vec_model.cuda()
219
+ # self.wav2vec_model.eval()
220
+
221
+ self.n_out_samples = 0
222
+ # create db for samples
223
+ if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir)
224
+ dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G
225
+ n_filtered_out = defaultdict(int)
226
+
227
+ for index, file_name in self.selected_file.iterrows():
228
+ f_name = file_name["id"]
229
+ ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
230
+ pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext
231
+ pose_each_file = []
232
+ trans_each_file = []
233
+ shape_each_file = []
234
+ audio_each_file = []
235
+ facial_each_file = []
236
+ word_each_file = []
237
+ emo_each_file = []
238
+ sem_each_file = []
239
+ vid_each_file = []
240
+ id_pose = f_name #1_wayne_0_1_1
241
+
242
+ logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue"))
243
+ if "smplx" in self.args.pose_rep:
244
+ pose_data = np.load(pose_file, allow_pickle=True)
245
+ assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30'
246
+ stride = int(30/self.args.pose_fps)
247
+ pose_each_file = pose_data["poses"][::stride] * self.joint_mask
248
+ pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)]
249
+ # print(pose_each_file.shape)
250
+ trans_each_file = pose_data["trans"][::stride]
251
+ shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0)
252
+ if self.args.facial_rep is not None:
253
+ logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
254
+ facial_each_file = pose_data["expressions"][::stride]
255
+ if self.args.facial_norm:
256
+ facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial
257
+
258
+ else:
259
+ assert 120%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120'
260
+ stride = int(120/self.args.pose_fps)
261
+ with open(pose_file, "r") as pose_data:
262
+ for j, line in enumerate(pose_data.readlines()):
263
+ if j < 431: continue
264
+ if j%stride != 0:continue
265
+ data = np.fromstring(line, dtype=float, sep=" ")
266
+ rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints,3), "XYZ")
267
+ rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3)
268
+ rot_data = rot_data.numpy() * self.joint_mask
269
+
270
+ pose_each_file.append(rot_data)
271
+ trans_each_file.append(data[:3])
272
+
273
+ pose_each_file = np.array(pose_each_file)
274
+ # print(pose_each_file.shape)
275
+ trans_each_file = np.array(trans_each_file)
276
+ shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0)
277
+ if self.args.facial_rep is not None:
278
+ logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
279
+ facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json")
280
+                    assert 60%self.args.pose_fps == 0, 'pose_fps should be a divisor of 60'
281
+ stride = int(60/self.args.pose_fps)
282
+ if not os.path.exists(facial_file):
283
+ logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #")
284
+ self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
285
+ continue
286
+ with open(facial_file, 'r') as facial_data_file:
287
+ facial_data = json.load(facial_data_file)
288
+ for j, frame_data in enumerate(facial_data['frames']):
289
+ if j%stride != 0:continue
290
+ facial_each_file.append(frame_data['weights'])
291
+ facial_each_file = np.array(facial_each_file)
292
+ if self.args.facial_norm:
293
+ facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial
294
+
295
+ if self.args.id_rep is not None:
296
+ vid_each_file = np.repeat(np.array(int(f_name.split("_")[0])-1).reshape(1, 1), pose_each_file.shape[0], axis=0)
297
+
298
+ if self.args.audio_rep is not None:
299
+ logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #")
300
+ audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
301
+ if not os.path.exists(audio_file):
302
+ logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #")
303
+ self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
304
+ continue
305
+ audio_each_file, sr = librosa.load(audio_file)
306
+ audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr)
307
+ if self.args.audio_rep == "onset+amplitude":
308
+ from numpy.lib import stride_tricks
309
+ frame_length = 1024
310
+ # hop_length = 512
311
+ shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length)
312
+ strides = (audio_each_file.strides[-1], audio_each_file.strides[-1])
313
+ rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides)
314
+ amplitude_envelope = np.max(np.abs(rolling_view), axis=1)
315
+ # pad the last frame_length-1 samples
316
+ amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1])
317
+ audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames')
318
+ onset_array = np.zeros(len(audio_each_file), dtype=float)
319
+ onset_array[audio_onset_f] = 1.0
320
+ # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape)
321
+ audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1)
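The amplitude envelope above is a rolling maximum built with `stride_tricks.as_strided`. On NumPy 1.20+ the same envelope can be computed with `sliding_window_view`, which avoids manual stride arithmetic; a hedged sketch (the synthetic waveform and frame length are placeholders):

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

signal = np.random.randn(16000).astype(np.float32)   # 1 s of fake 16 kHz audio
frame_length = 1024

windows = sliding_window_view(signal, frame_length)   # (N - frame_length + 1, frame_length)
envelope = np.abs(windows).max(axis=1)
envelope = np.pad(envelope, (0, frame_length - 1), mode="edge")  # same trailing pad as above
assert envelope.shape == signal.shape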
322
+ elif self.args.audio_rep == "mfcc":
323
+ audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps))
324
+ audio_each_file = audio_each_file.transpose(1, 0)
325
+ # print(audio_each_file.shape, pose_each_file.shape)
326
+ if self.args.audio_norm and self.args.audio_rep == "wave16k":
327
+ audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio
328
+ # print(audio_each_file.shape)
329
+ time_offset = 0
330
+ if self.args.word_rep is not None:
331
+ logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #")
332
+ word_file = f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid"
333
+ if not os.path.exists(word_file):
334
+ logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #")
335
+ self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
336
+ continue
337
+ tgrid = tg.TextGrid.fromFile(word_file)
338
+ if self.args.t_pre_encoder == "bert":
339
+ from transformers import AutoTokenizer, BertModel
340
+ tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True)
341
+ model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval()
342
+ list_word = []
343
+ all_hidden = []
344
+ max_len = 400
345
+ last = 0
346
+ word_token_mapping = []
347
+ first = True
348
+ for i, word in enumerate(tgrid[0]):
349
+ last = i
350
+ if (i%max_len != 0) or (i==0):
351
+ if word.mark == "":
352
+ list_word.append(".")
353
+ else:
354
+ list_word.append(word.mark)
355
+ else:
356
+ max_counter = max_len
357
+ str_word = ' '.join(map(str, list_word))
358
+ if first:
359
+ global_len = 0
360
+ end = -1
361
+ offset_word = []
362
+ for k, wordvalue in enumerate(list_word):
363
+ start = end+1
364
+ end = start+len(wordvalue)
365
+ offset_word.append((start, end))
366
+ #print(offset_word)
367
+ token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping']
368
+ #print(token_scan)
369
+ for start, end in offset_word:
370
+ sub_mapping = []
371
+ for i, (start_t, end_t) in enumerate(token_scan[1:-1]):
372
+ if int(start) <= int(start_t) and int(end_t) <= int(end):
373
+ #print(i+global_len)
374
+ sub_mapping.append(i+global_len)
375
+ word_token_mapping.append(sub_mapping)
376
+ #print(len(word_token_mapping))
377
+ global_len = word_token_mapping[-1][-1] + 1
378
+ list_word = []
379
+ if word.mark == "":
380
+ list_word.append(".")
381
+ else:
382
+ list_word.append(word.mark)
383
+
384
+ with torch.no_grad():
385
+ inputs = tokenizer(str_word, return_tensors="pt")
386
+ outputs = model(**inputs)
387
+ last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :]
388
+ all_hidden.append(last_hidden_states)
389
+
390
+ #list_word = list_word[:10]
391
+ if list_word == []:
392
+ pass
393
+ else:
394
+ if first:
395
+ global_len = 0
396
+ str_word = ' '.join(map(str, list_word))
397
+ end = -1
398
+ offset_word = []
399
+ for k, wordvalue in enumerate(list_word):
400
+ start = end+1
401
+ end = start+len(wordvalue)
402
+ offset_word.append((start, end))
403
+ #print(offset_word)
404
+ token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping']
405
+ #print(token_scan)
406
+ for start, end in offset_word:
407
+ sub_mapping = []
408
+ for i, (start_t, end_t) in enumerate(token_scan[1:-1]):
409
+ if int(start) <= int(start_t) and int(end_t) <= int(end):
410
+ sub_mapping.append(i+global_len)
411
+ #print(sub_mapping)
412
+ word_token_mapping.append(sub_mapping)
413
+ #print(len(word_token_mapping))
414
+ with torch.no_grad():
415
+ inputs = tokenizer(str_word, return_tensors="pt")
416
+ outputs = model(**inputs)
417
+ last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :]
418
+ all_hidden.append(last_hidden_states)
419
+ last_hidden_states = np.concatenate(all_hidden, axis=0)
420
+
421
+ for i in range(pose_each_file.shape[0]):
422
+ found_flag = False
423
+ current_time = i/self.args.pose_fps + time_offset
424
+ j_last = 0
425
+ for j, word in enumerate(tgrid[0]):
426
+ word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
427
+ if word_s<=current_time and current_time<=word_e:
428
+ if self.args.word_cache and self.args.t_pre_encoder == 'bert':
429
+ mapping_index = word_token_mapping[j]
430
+ #print(mapping_index, word_s, word_e)
431
+ s_t = np.linspace(word_s, word_e, len(mapping_index)+1)
432
+ #print(s_t)
433
+ for tt, t_sep in enumerate(s_t[1:]):
434
+ if current_time <= t_sep:
435
+ #if len(mapping_index) > 1: print(mapping_index[tt])
436
+ word_each_file.append(last_hidden_states[mapping_index[tt]])
437
+ break
438
+ else:
439
+ if word_n == " ":
440
+ word_each_file.append(self.lang_model.PAD_token)
441
+ else:
442
+ word_each_file.append(self.lang_model.get_word_index(word_n))
443
+ found_flag = True
444
+ j_last = j
445
+ break
446
+ else: continue
447
+ if not found_flag:
448
+ if self.args.word_cache and self.args.t_pre_encoder == 'bert':
449
+ word_each_file.append(last_hidden_states[j_last])
450
+ else:
451
+ word_each_file.append(self.lang_model.UNK_token)
452
+ word_each_file = np.array(word_each_file)
453
+ #print(word_each_file.shape)
454
+
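The per-frame word loop above assigns one label per pose frame by finding the TextGrid interval that contains the frame's timestamp. A compact sketch of the same interval-to-frame alignment on made-up intervals (interval list, fps and word indices are illustrative):

import numpy as np

def align_words_to_frames(intervals, n_frames, fps, pad_idx=0):
    # intervals: list of (start_s, end_s, word_idx); returns one word index per frame
    labels = np.full(n_frames, pad_idx, dtype=np.int64)
    for i in range(n_frames):
        t = i / fps
        for start, end, idx in intervals:
            if start <= t <= end:
                labels[i] = idx
                break
    return labels

print(align_words_to_frames([(0.0, 0.4, 7), (0.4, 0.9, 12)], n_frames=30, fps=30))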
455
+ if self.args.emo_rep is not None:
456
+ logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #")
457
+ rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3])
458
+ if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6:
459
+ if start >= 1 and start <= 64:
460
+ score = 0
461
+ elif start >= 65 and start <= 72:
462
+ score = 1
463
+ elif start >= 73 and start <= 80:
464
+ score = 2
465
+ elif start >= 81 and start <= 86:
466
+ score = 3
467
+ elif start >= 87 and start <= 94:
468
+ score = 4
469
+ elif start >= 95 and start <= 102:
470
+ score = 5
471
+ elif start >= 103 and start <= 110:
472
+ score = 6
473
+ elif start >= 111 and start <= 118:
474
+ score = 7
475
+                    else: score = 0  # fallback for out-of-range sequence numbers
476
+ else:
477
+ # you may denote as unknown in the future
478
+ score = 0
479
+ emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0)
480
+ #print(emo_each_file)
481
+
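The if/elif ladder above maps the clip's sequence number to one of eight emotion classes. The same table can be written as a small range lookup, which is easier to audit; the ranges are copied from the code above, while the helper name is made up:

EMOTION_RANGES = [
    (1, 64, 0), (65, 72, 1), (73, 80, 2), (81, 86, 3),
    (87, 94, 4), (95, 102, 5), (103, 110, 6), (111, 118, 7),
]

def emotion_score(seq_number):
    for lo, hi, score in EMOTION_RANGES:
        if lo <= seq_number <= hi:
            return score
    return 0  # out-of-range sequences default to neutral

assert emotion_score(70) == 1 and emotion_score(118) == 7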
482
+ if self.args.sem_rep is not None:
483
+ logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #")
484
+ sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt"
485
+ sem_all = pd.read_csv(sem_file,
486
+ sep='\t',
487
+ names=["name", "start_time", "end_time", "duration", "score", "keywords"])
488
+ # we adopt motion-level semantic score here.
489
+ for i in range(pose_each_file.shape[0]):
490
+ found_flag = False
491
+ for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])):
492
+ current_time = i/self.args.pose_fps + time_offset
493
+ if start<=current_time and current_time<=end:
494
+ sem_each_file.append(score)
495
+ found_flag=True
496
+ break
497
+ else: continue
498
+ if not found_flag: sem_each_file.append(0.)
499
+ sem_each_file = np.array(sem_each_file)
500
+ #print(sem_each_file)
501
+
502
+ filtered_result = self._sample_from_clip(
503
+ dst_lmdb_env,
504
+ audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file,
505
+ vid_each_file, emo_each_file, sem_each_file,
506
+ disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
507
+ )
508
+ for type in filtered_result.keys():
509
+ n_filtered_out[type] += filtered_result[type]
510
+
511
+ with dst_lmdb_env.begin() as txn:
512
+ logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan"))
513
+ n_total_filtered = 0
514
+ for type, n_filtered in n_filtered_out.items():
515
+ logger.info("{}: {}".format(type, n_filtered))
516
+ n_total_filtered += n_filtered
517
+ logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format(
518
+ n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan"))
519
+ dst_lmdb_env.sync()
520
+ dst_lmdb_env.close()
521
+
522
+ def _sample_from_clip(
523
+ self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file,
524
+ vid_each_file, emo_each_file, sem_each_file,
525
+ disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
526
+ ):
527
+ """
528
+ for data cleaning, we ignore the data for first and final n s
529
+        for data cleaning, we ignore the first and final n seconds of each clip
530
+ """
531
+ # audio_start = int(self.alignment[0] * self.args.audio_fps)
532
+ # pose_start = int(self.alignment[1] * self.args.pose_fps)
533
+ #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}")
534
+ # audio_each_file = audio_each_file[audio_start:]
535
+ # pose_each_file = pose_each_file[pose_start:]
536
+ # trans_each_file =
537
+ #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}")
538
+ #print(pose_each_file.shape)
539
+ round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s
540
+ #print(round_seconds_skeleton)
541
+        if len(audio_each_file) != 0:  # empty list when audio_rep is None, otherwise a numpy array
542
+            if self.args.audio_rep == "mfcc":
543
+                round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps
544
+            elif self.args.audio_rep != "wave16k":
545
+                round_seconds_audio = len(audio_each_file) // self.args.audio_fps # e.g. 1,600,000 samples / 16,000 = 100 s
546
+            else:
547
+                round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr
548
+            if len(facial_each_file) != 0:
549
+ round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps
550
+ logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s")
551
+ round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
552
+ max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
553
+ if round_seconds_skeleton != max_round:
554
+ logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s")
555
+ else:
556
+ logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s")
557
+ round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton)
558
+ max_round = max(round_seconds_audio, round_seconds_skeleton)
559
+ if round_seconds_skeleton != max_round:
560
+ logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s")
561
+
562
+ clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s
563
+ clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000]
564
+ clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15]
565
+
566
+
567
+ for ratio in self.args.multi_length_training:
568
+ if is_test:# stride = length for test
569
+ cut_length = clip_e_f_pose - clip_s_f_pose
570
+ self.args.stride = cut_length
571
+ self.max_length = cut_length
572
+ else:
573
+ self.args.stride = int(ratio*self.ori_stride)
574
+ cut_length = int(self.ori_length*ratio)
575
+
576
+ num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1
577
+ logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}")
578
+            logger.info(f"{num_subdivision} clips are expected with stride {self.args.stride}")
579
+
580
+            if len(audio_each_file) != 0:
581
+ audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps)
582
+ """
583
+ for audio sr = 16000, fps = 15, pose_length = 34,
584
+ audio short length = 36266.7 -> 36266
585
+ this error is fine.
586
+ """
587
+ logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}")
588
+
589
+ n_filtered_out = defaultdict(int)
590
+ sample_pose_list = []
591
+ sample_audio_list = []
592
+ sample_facial_list = []
593
+ sample_shape_list = []
594
+ sample_word_list = []
595
+ sample_emo_list = []
596
+ sample_sem_list = []
597
+ sample_vid_list = []
598
+ sample_trans_list = []
599
+
600
+            for i in range(num_subdivision): # cut the sequence into clips of cut_length frames
601
+ start_idx = clip_s_f_pose + i * self.args.stride
602
+ fin_idx = start_idx + cut_length
603
+ sample_pose = pose_each_file[start_idx:fin_idx]
604
+ sample_trans = trans_each_file[start_idx:fin_idx]
605
+ sample_shape = shape_each_file[start_idx:fin_idx]
606
+ # print(sample_pose.shape)
607
+ if self.args.audio_rep is not None:
608
+ audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps)
609
+ audio_end = audio_start + audio_short_length
610
+ sample_audio = audio_each_file[audio_start:audio_end]
611
+ else:
612
+ sample_audio = np.array([-1])
613
+ sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1])
614
+ sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1])
615
+ sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1])
616
+ sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1])
617
+ sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1])
618
+
619
+                if sample_pose is not None:
620
+ # filtering motion skeleton data
621
+ sample_pose, filtering_message = MotionPreprocessor(sample_pose).get()
622
+                    is_correct_motion = len(sample_pose) > 0  # MotionPreprocessor returns [] for filtered clips
623
+ if is_correct_motion or disable_filtering:
624
+ sample_pose_list.append(sample_pose)
625
+ sample_audio_list.append(sample_audio)
626
+ sample_facial_list.append(sample_facial)
627
+ sample_shape_list.append(sample_shape)
628
+ sample_word_list.append(sample_word)
629
+ sample_vid_list.append(sample_vid)
630
+ sample_emo_list.append(sample_emo)
631
+ sample_sem_list.append(sample_sem)
632
+ sample_trans_list.append(sample_trans)
633
+ else:
634
+ n_filtered_out[filtering_message] += 1
635
+
636
+ if len(sample_pose_list) > 0:
637
+ with dst_lmdb_env.begin(write=True) as txn:
638
+ for pose, audio, facial, shape, word, vid, emo, sem, trans in zip(
639
+ sample_pose_list,
640
+ sample_audio_list,
641
+ sample_facial_list,
642
+ sample_shape_list,
643
+ sample_word_list,
644
+ sample_vid_list,
645
+ sample_emo_list,
646
+ sample_sem_list,
647
+ sample_trans_list,):
648
+ k = "{:005}".format(self.n_out_samples).encode("ascii")
649
+ v = [pose, audio, facial, shape, word, emo, sem, vid, trans]
650
+ v = pickle.dumps(v,5)
651
+ txn.put(k, v)
652
+ self.n_out_samples += 1
653
+ return n_filtered_out
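Samples are written to LMDB under zero-padded ASCII keys so that lexicographic key order matches numeric sample order, and read back in `__getitem__` with the same format string. A tiny round-trip sketch of that convention (throwaway temp directory, toy payload):

import pickle
import tempfile

import lmdb
import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    env = lmdb.open(tmp, map_size=1 << 24)
    with env.begin(write=True) as txn:
        for idx in range(3):
            txn.put("{:005}".format(idx).encode("ascii"),
                    pickle.dumps(np.full(2, idx), protocol=5))
    with env.begin() as txn:
        print(txn.stat()["entries"])            # 3
        print(pickle.loads(txn.get(b"00001")))  # [1 1]
    env.close()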
654
+
655
+ def __getitem__(self, idx):
656
+ with self.lmdb_env.begin(write=False) as txn:
657
+ key = "{:005}".format(idx).encode("ascii")
658
+ sample = txn.get(key)
659
+ sample = pickle.loads(sample)
660
+ tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample
661
+ #print(in_shape)
662
+ #vid = torch.from_numpy(vid).int()
663
+ emo = torch.from_numpy(emo).int()
664
+ sem = torch.from_numpy(sem).float()
665
+ in_audio = torch.from_numpy(in_audio).float()
666
+ in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int()
667
+ if self.loader_type == "test":
668
+ tar_pose = torch.from_numpy(tar_pose).float()
669
+ trans = torch.from_numpy(trans).float()
670
+ in_facial = torch.from_numpy(in_facial).float()
671
+ vid = torch.from_numpy(vid).float()
672
+ in_shape = torch.from_numpy(in_shape).float()
673
+ else:
674
+ in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
675
+ trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float()
676
+ vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float()
677
+ tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float()
678
+ in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float()
679
+ return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans}
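A minimal usage sketch of this dataset with a PyTorch DataLoader; `args` is a stand-in for the config namespace expected by `__init__`, so the construction is left commented:

from torch.utils.data import DataLoader

# Hypothetical usage; `args` must provide the fields read in CustomDataset.__init__.
# dataset = CustomDataset(args, loader_type="train")
# loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, drop_last=True)
# for batch in loader:
#     pose, audio = batch["pose"], batch["audio"]  # (B, T, C_pose), (B, T_audio, C_audio)
#     break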
680
+
681
+
682
+ class MotionPreprocessor:
683
+ def __init__(self, skeletons):
684
+ self.skeletons = skeletons
685
+ #self.mean_pose = mean_pose
686
+ self.filtering_message = "PASS"
687
+
688
+ def get(self):
689
+ assert (self.skeletons is not None)
690
+
691
+ # filtering
692
+ if self.skeletons != []:
693
+ if self.check_pose_diff():
694
+ self.skeletons = []
695
+ self.filtering_message = "pose"
696
+ # elif self.check_spine_angle():
697
+ # self.skeletons = []
698
+ # self.filtering_message = "spine angle"
699
+ # elif self.check_static_motion():
700
+ # self.skeletons = []
701
+ # self.filtering_message = "motion"
702
+
703
+ # if self.skeletons != []:
704
+ # self.skeletons = self.skeletons.tolist()
705
+ # for i, frame in enumerate(self.skeletons):
706
+ # assert not np.isnan(self.skeletons[i]).any() # missing joints
707
+
708
+ return self.skeletons, self.filtering_message
709
+
710
+ def check_static_motion(self, verbose=True):
711
+ def get_variance(skeleton, joint_idx):
712
+ wrist_pos = skeleton[:, joint_idx]
713
+ variance = np.sum(np.var(wrist_pos, axis=0))
714
+ return variance
715
+
716
+ left_arm_var = get_variance(self.skeletons, 6)
717
+ right_arm_var = get_variance(self.skeletons, 9)
718
+
719
+ th = 0.0014 # exclude 13110
720
+ # th = 0.002 # exclude 16905
721
+ if left_arm_var < th and right_arm_var < th:
722
+ if verbose:
723
+ print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
724
+ return True
725
+ else:
726
+ if verbose:
727
+ print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
728
+ return False
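check_static_motion flags clips whose wrist trajectories barely move by thresholding the summed positional variance. A toy run of the same test (joint indices 6/9 and the 0.0014 threshold come from the method above; the skeleton arrays are synthetic):

import numpy as np

def is_static(skeleton, joint_idx=(6, 9), threshold=0.0014):
    # skeleton: (T, J, 3) joint positions
    return all(np.sum(np.var(skeleton[:, j], axis=0)) < threshold for j in joint_idx)

still = np.zeros((30, 10, 3))
moving = still + np.random.randn(30, 10, 3) * 0.1
print(is_static(still), is_static(moving))  # True False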
729
+
730
+
731
+ def check_pose_diff(self, verbose=False):
732
+ # diff = np.abs(self.skeletons - self.mean_pose) # 186*1
733
+ # diff = np.mean(diff)
734
+
735
+ # # th = 0.017
736
+ # th = 0.02 #0.02 # exclude 3594
737
+ # if diff < th:
738
+ # if verbose:
739
+ # print("skip - check_pose_diff {:.5f}".format(diff))
740
+ # return True
741
+ # # th = 3.5 #0.02 # exclude 3594
742
+ # # if 3.5 < diff < 5:
743
+ # # if verbose:
744
+ # # print("skip - check_pose_diff {:.5f}".format(diff))
745
+ # # return True
746
+ # else:
747
+ # if verbose:
748
+ # print("pass - check_pose_diff {:.5f}".format(diff))
749
+ return False
750
+
751
+
752
+ def check_spine_angle(self, verbose=True):
753
+ def angle_between(v1, v2):
754
+ v1_u = v1 / np.linalg.norm(v1)
755
+ v2_u = v2 / np.linalg.norm(v2)
756
+ return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
757
+
758
+ angles = []
759
+ for i in range(self.skeletons.shape[0]):
760
+ spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0]
761
+ angle = angle_between(spine_vec, [0, -1, 0])
762
+ angles.append(angle)
763
+
764
+ if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495
765
+ # if np.rad2deg(max(angles)) > 20: # exclude 8270
766
+ if verbose:
767
+ print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles)))
768
+ return True
769
+ else:
770
+ if verbose:
771
+ print("pass - check_spine_angle {:.5f}".format(max(angles)))
772
+ return False
dataloaders/beat_sep_lower.py ADDED
@@ -0,0 +1,430 @@
1
+ import os
2
+ import pickle
3
+ import math
4
+ import shutil
5
+ import numpy as np
6
+ import lmdb as lmdb
7
+ import pandas as pd
8
+ import torch
9
+ import glob
10
+ import json
11
+ from dataloaders.build_vocab import Vocab
12
+ from termcolor import colored
13
+ from loguru import logger
14
+ from collections import defaultdict
15
+ from torch.utils.data import Dataset
16
+ import torch.distributed as dist
17
+ import pickle
18
+ import smplx
19
+ from .utils.audio_features import process_audio_data
20
+ from .data_tools import joints_list
21
+ from .utils.other_tools import MultiLMDBManager
22
+ from .utils.motion_rep_transfer import process_smplx_motion
23
+ from .utils.mis_features import process_semantic_data, process_emotion_data
24
+ from .utils.text_features import process_word_data
25
+ from .utils.data_sample import sample_from_clip
26
+ import time
27
+
28
+
29
+ class CustomDataset(Dataset):
30
+ def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
31
+ self.args = args
32
+ self.loader_type = loader_type
33
+
34
+ # Set rank safely - handle cases where distributed training is not yet initialized
35
+ try:
36
+ if torch.distributed.is_initialized():
37
+ self.rank = torch.distributed.get_rank()
38
+ else:
39
+ self.rank = 0
40
+ except:
41
+ self.rank = 0
42
+
43
+ self.ori_stride = self.args.stride
44
+ self.ori_length = self.args.pose_length
45
+
46
+ # Initialize basic parameters
47
+ self.ori_stride = self.args.stride
48
+ self.ori_length = self.args.pose_length
49
+ self.alignment = [0,0] # for trinity
50
+
51
+ """Initialize SMPLX model."""
52
+ self.smplx = smplx.create(
53
+ self.args.data_path_1+"smplx_models/",
54
+ model_type='smplx',
55
+ gender='NEUTRAL_2020',
56
+ use_face_contour=False,
57
+ num_betas=300,
58
+ num_expression_coeffs=100,
59
+ ext='npz',
60
+ use_pca=False,
61
+ ).cuda().eval()
62
+
63
+ if self.args.word_rep is not None:
64
+ with open(f"{self.args.data_path}weights/vocab.pkl", 'rb') as f:
65
+ self.lang_model = pickle.load(f)
66
+
67
+ # Load and process split rules
68
+ self._process_split_rules()
69
+
70
+ # Initialize data directories and lengths
71
+ self._init_data_paths()
72
+
73
+ if self.args.beat_align:
74
+ if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"):
75
+ self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
76
+ self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
77
+
78
+ # Build or load cache
79
+ self._init_cache(build_cache)
80
+
81
+ def _process_split_rules(self):
82
+ """Process dataset split rules."""
83
+ split_rule = pd.read_csv(self.args.data_path+"train_test_split.csv")
84
+ self.selected_file = split_rule.loc[
85
+ (split_rule['type'] == self.loader_type) &
86
+ (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
87
+ ]
88
+
89
+ if self.args.additional_data and self.loader_type == 'train':
90
+ split_b = split_rule.loc[
91
+ (split_rule['type'] == 'additional') &
92
+ (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
93
+ ]
94
+ self.selected_file = pd.concat([self.selected_file, split_b])
95
+
96
+ if self.selected_file.empty:
97
+ logger.warning(f"{self.loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead")
98
+ self.selected_file = split_rule.loc[
99
+ (split_rule['type'] == 'train') &
100
+ (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
101
+ ]
102
+ self.selected_file = self.selected_file.iloc[0:8]
103
+
104
+ def _init_data_paths(self):
105
+ """Initialize data directories and lengths."""
106
+ self.data_dir = self.args.data_path
107
+
108
+ if self.loader_type == "test":
109
+ self.args.multi_length_training = [1.0]
110
+
111
+ self.max_length = int(self.args.pose_length * self.args.multi_length_training[-1])
112
+ self.max_audio_pre_len = math.floor(self.args.pose_length / self.args.pose_fps * self.args.audio_sr)
113
+
114
+ if self.max_audio_pre_len > self.args.test_length * self.args.audio_sr:
115
+ self.max_audio_pre_len = self.args.test_length * self.args.audio_sr
116
+
117
+ if self.args.test_clip and self.loader_type == "test":
118
+ self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + "_clip" + f"/{self.args.pose_rep}_cache"
119
+ else:
120
+ self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + f"/{self.args.pose_rep}_cache"
121
+
122
+ def _init_cache(self, build_cache):
123
+ """Initialize or build cache."""
124
+ self.lmdb_envs = {}
125
+ self.mapping_data = None
126
+
127
+ if build_cache and self.rank == 0:
128
+ self.build_cache(self.preloaded_dir)
129
+
130
+ # In DDP mode, ensure all processes wait for cache building to complete
131
+ if torch.distributed.is_initialized():
132
+ torch.distributed.barrier()
133
+
134
+ # Try to regenerate cache if corrupted (only on rank 0 to avoid race conditions)
135
+ if self.rank == 0:
136
+ self.regenerate_cache_if_corrupted()
137
+
138
+ # Wait for cache regeneration to complete
139
+ if torch.distributed.is_initialized():
140
+ torch.distributed.barrier()
141
+
142
+ self.load_db_mapping()
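The cache logic above follows the usual rank-0-builds, everyone-waits pattern so that only one process writes to disk under DDP. A stripped-down sketch of that pattern in isolation (the helper and `build_fn` are placeholders, not part of this repository):

import torch.distributed as dist

def build_once_then_sync(rank, build_fn):
    # Only rank 0 builds the cache; every rank then waits at the barrier,
    # so readers never open a half-written cache.
    if rank == 0:
        build_fn()
    if dist.is_initialized():
        dist.barrier()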
143
+
144
+ def build_cache(self, preloaded_dir):
145
+ """Build the dataset cache."""
146
+ logger.info(f"Audio bit rate: {self.args.audio_fps}")
147
+ logger.info("Reading data '{}'...".format(self.data_dir))
148
+ logger.info("Creating the dataset cache...")
149
+
150
+ if self.args.new_cache and os.path.exists(preloaded_dir):
151
+ shutil.rmtree(preloaded_dir)
152
+
153
+ if os.path.exists(preloaded_dir):
154
+ # if the dir is empty, that means we still need to build the cache
155
+ if not os.listdir(preloaded_dir):
156
+ self.cache_generation(
157
+ preloaded_dir,
158
+ self.args.disable_filtering,
159
+ self.args.clean_first_seconds,
160
+ self.args.clean_final_seconds,
161
+ is_test=False
162
+ )
163
+ else:
164
+ logger.info("Found the cache {}".format(preloaded_dir))
165
+
166
+ elif self.loader_type == "test":
167
+ self.cache_generation(preloaded_dir, True, 0, 0, is_test=True)
168
+ else:
169
+ self.cache_generation(
170
+ preloaded_dir,
171
+ self.args.disable_filtering,
172
+ self.args.clean_first_seconds,
173
+ self.args.clean_final_seconds,
174
+ is_test=False
175
+ )
176
+
177
+ def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
178
+ """Generate cache for the dataset."""
179
+ if not os.path.exists(out_lmdb_dir):
180
+ os.makedirs(out_lmdb_dir)
181
+
182
+ # Initialize the multi-LMDB manager
183
+ lmdb_manager = MultiLMDBManager(out_lmdb_dir, max_db_size=10*1024*1024*1024)
184
+
185
+ self.n_out_samples = 0
186
+ n_filtered_out = defaultdict(int)
187
+
188
+ for index, file_name in self.selected_file.iterrows():
189
+ f_name = file_name["id"]
190
+ ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
191
+ pose_file = os.path.join(self.data_dir, self.args.pose_rep, f_name + ext)
192
+
193
+ # Process data
194
+ data = self._process_file_data(f_name, pose_file, ext)
195
+ if data is None:
196
+ continue
197
+
198
+ # Sample from clip
199
+ filtered_result, self.n_out_samples = sample_from_clip(
200
+ lmdb_manager=lmdb_manager,
201
+ audio_file=pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav"),
202
+ audio_each_file=data['audio'],
203
+ pose_each_file=data['pose'],
204
+ trans_each_file=data['trans'],
205
+ trans_v_each_file=data['trans_v'],
206
+ shape_each_file=data['shape'],
207
+ facial_each_file=data['facial'],
208
+ word_each_file=data['word'],
209
+ vid_each_file=data['vid'],
210
+ emo_each_file=data['emo'],
211
+ sem_each_file=data['sem'],
212
+ args=self.args,
213
+ ori_stride=self.ori_stride,
214
+ ori_length=self.ori_length,
215
+ disable_filtering=disable_filtering,
216
+ clean_first_seconds=clean_first_seconds,
217
+ clean_final_seconds=clean_final_seconds,
218
+ is_test=is_test,
219
+ n_out_samples=self.n_out_samples
220
+ )
221
+
222
+ for type_key in filtered_result:
223
+ n_filtered_out[type_key] += filtered_result[type_key]
224
+
225
+ lmdb_manager.close()
226
+
227
+ def _process_file_data(self, f_name, pose_file, ext):
228
+ """Process all data for a single file."""
229
+ data = {
230
+ 'pose': None, 'trans': None, 'trans_v': None, 'shape': None,
231
+ 'audio': None, 'facial': None, 'word': None, 'emo': None,
232
+ 'sem': None, 'vid': None
233
+ }
234
+
235
+ # Process motion data
236
+ logger.info(colored(f"# ---- Building cache for Pose {f_name} ---- #", "blue"))
237
+ if "smplx" in self.args.pose_rep:
238
+ motion_data = process_smplx_motion(pose_file, self.smplx, self.args.pose_fps, self.args.facial_rep)
239
+ else:
240
+ raise ValueError(f"Unknown pose representation '{self.args.pose_rep}'.")
241
+
242
+ if motion_data is None:
243
+ return None
244
+
245
+ data.update(motion_data)
246
+
247
+ # Process speaker ID
248
+ if self.args.id_rep is not None:
249
+ speaker_id = int(f_name.split("_")[0]) - 1
250
+ data['vid'] = np.repeat(np.array(speaker_id).reshape(1, 1), data['pose'].shape[0], axis=0)
251
+ else:
252
+ data['vid'] = np.array([-1])
253
+
254
+ # Process audio if needed
255
+ if self.args.audio_rep is not None:
256
+ audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
257
+ data = process_audio_data(audio_file, self.args, data, f_name, self.selected_file)
258
+ if data is None:
259
+ return None
260
+
261
+ # Process emotion if needed
262
+ if self.args.emo_rep is not None:
263
+ data = process_emotion_data(f_name, data, self.args)
264
+ if data is None:
265
+ return None
266
+
267
+ # Process word data if needed
268
+ if self.args.word_rep is not None:
269
+ word_file = f"{self.data_dir}{self.args.word_rep}/{f_name}.TextGrid"
270
+ data = process_word_data(self.data_dir, word_file, self.args, data, f_name, self.selected_file, self.lang_model)
271
+ if data is None:
272
+ return None
273
+
274
+ # Process semantic data if needed
275
+ if self.args.sem_rep is not None:
276
+ sem_file = f"{self.data_dir}{self.args.sem_rep}/{f_name}.txt"
277
+ data = process_semantic_data(sem_file, self.args, data, f_name)
278
+ if data is None:
279
+ return None
280
+
281
+ return data
282
+
283
+ def load_db_mapping(self):
284
+ """Load database mapping from file."""
285
+ mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl")
286
+ backup_path = os.path.join(self.preloaded_dir, "sample_db_mapping_backup.pkl")
287
+
288
+ # Check if file exists and is readable
289
+ if not os.path.exists(mapping_path):
290
+ raise FileNotFoundError(f"Mapping file not found: {mapping_path}")
291
+
292
+ # Check file size to ensure it's not empty
293
+ file_size = os.path.getsize(mapping_path)
294
+ if file_size == 0:
295
+ raise ValueError(f"Mapping file is empty: {mapping_path}")
296
+
297
+ print(f"Loading mapping file: {mapping_path} (size: {file_size} bytes)")
298
+
299
+ # Add error handling and retry logic for pickle loading
300
+ max_retries = 3
301
+ for attempt in range(max_retries):
302
+ try:
303
+ with open(mapping_path, 'rb') as f:
304
+ self.mapping_data = pickle.load(f)
305
+ print(f"Successfully loaded mapping data with {len(self.mapping_data.get('mapping', []))} samples")
306
+ break
307
+ except (EOFError, pickle.UnpicklingError) as e:
308
+ if attempt < max_retries - 1:
309
+ print(f"Warning: Failed to load pickle file (attempt {attempt + 1}/{max_retries}): {e}")
310
+ print(f"File path: {mapping_path}")
311
+
312
+ # Try backup file if main file is corrupted
313
+ if os.path.exists(backup_path) and os.path.getsize(backup_path) > 0:
314
+ print("Trying backup file...")
315
+ try:
316
+ with open(backup_path, 'rb') as f:
317
+ self.mapping_data = pickle.load(f)
318
+ print(f"Successfully loaded mapping data from backup with {len(self.mapping_data.get('mapping', []))} samples")
319
+ break
320
+ except Exception as backup_e:
321
+ print(f"Backup file also failed: {backup_e}")
322
+
323
+ print("Retrying...")
324
+ time.sleep(1) # Wait a bit before retrying
325
+ else:
326
+ print(f"Error: Failed to load pickle file after {max_retries} attempts: {e}")
327
+ print(f"File path: {mapping_path}")
328
+ print("Please check if the file is corrupted or incomplete.")
329
+ print("You may need to regenerate the cache files.")
330
+ raise
331
+
332
+ # Update paths from test to test_clip if needed
333
+ if self.loader_type == "test" and self.args.test_clip:
334
+ updated_paths = []
335
+ for path in self.mapping_data['db_paths']:
336
+ updated_path = path.replace("test/", "test_clip/")
337
+ updated_paths.append(updated_path)
338
+ self.mapping_data['db_paths'] = updated_paths
339
+
340
+ # In DDP mode, avoid modifying shared files to prevent race conditions
341
+ # Instead, just update the in-memory data
342
+ print(f"Updated test paths for test_clip mode (avoiding file modification in DDP)")
343
+
344
+ self.n_samples = len(self.mapping_data['mapping'])
345
+
346
+ def get_lmdb_env(self, db_idx):
347
+ """Get LMDB environment for given database index."""
348
+ if db_idx not in self.lmdb_envs:
349
+ db_path = self.mapping_data['db_paths'][db_idx]
350
+ self.lmdb_envs[db_idx] = lmdb.open(db_path, readonly=True, lock=False)
351
+ return self.lmdb_envs[db_idx]
352
+
353
+ def __len__(self):
354
+ """Return the total number of samples in the dataset."""
355
+ return self.n_samples
356
+
357
+ def __getitem__(self, idx):
358
+ """Get a single sample from the dataset."""
359
+ db_idx = self.mapping_data['mapping'][idx]
360
+ lmdb_env = self.get_lmdb_env(db_idx)
361
+
362
+ with lmdb_env.begin(write=False) as txn:
363
+ key = "{:008d}".format(idx).encode("ascii")
364
+ sample = txn.get(key)
365
+ sample = pickle.loads(sample)
366
+
367
+ tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans, trans_v, audio_name = sample
368
+
369
+ # Convert data to tensors with appropriate types
370
+ processed_data = self._convert_to_tensors(
371
+ tar_pose, in_audio, in_facial, in_shape, in_word,
372
+ emo, sem, vid, trans, trans_v
373
+ )
374
+
375
+ processed_data['audio_name'] = audio_name
376
+ return processed_data
377
+
378
+ def _convert_to_tensors(self, tar_pose, in_audio, in_facial, in_shape, in_word,
379
+ emo, sem, vid, trans, trans_v):
380
+ """Convert numpy arrays to tensors with appropriate types."""
381
+ data = {
382
+ 'emo': torch.from_numpy(emo).int(),
383
+ 'sem': torch.from_numpy(sem).float(),
384
+ 'audio_onset': torch.from_numpy(in_audio).float(),
385
+ 'word': torch.from_numpy(in_word).int()
386
+ }
387
+
388
+ if self.loader_type == "test":
389
+ data.update({
390
+ 'pose': torch.from_numpy(tar_pose).float(),
391
+ 'trans': torch.from_numpy(trans).float(),
392
+ 'trans_v': torch.from_numpy(trans_v).float(),
393
+ 'facial': torch.from_numpy(in_facial).float(),
394
+ 'id': torch.from_numpy(vid).float(),
395
+ 'beta': torch.from_numpy(in_shape).float()
396
+ })
397
+ else:
398
+ data.update({
399
+ 'pose': torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float(),
400
+ 'trans': torch.from_numpy(trans).reshape((trans.shape[0], -1)).float(),
401
+ 'trans_v': torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float(),
402
+ 'facial': torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float(),
403
+ 'id': torch.from_numpy(vid).reshape((vid.shape[0], -1)).float(),
404
+ 'beta': torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
405
+ })
406
+
407
+ return data
408
+
409
+ def regenerate_cache_if_corrupted(self):
410
+ """Regenerate cache if the pickle file is corrupted."""
411
+ mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl")
412
+
413
+ if os.path.exists(mapping_path):
414
+ try:
415
+ # Try to load the file to check if it's corrupted
416
+ with open(mapping_path, 'rb') as f:
417
+ test_data = pickle.load(f)
418
+ return False # File is not corrupted
419
+ except (EOFError, pickle.UnpicklingError):
420
+ print(f"Detected corrupted pickle file: {mapping_path}")
421
+ print("Regenerating cache...")
422
+
423
+ # Remove corrupted file
424
+ os.remove(mapping_path)
425
+
426
+ # Regenerate cache
427
+ self.build_cache(self.preloaded_dir)
428
+ return True
429
+
430
+ return False
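Corrupted mapping pickles are usually the result of a writer being interrupted mid-dump; writing to a temporary file and renaming it atomically avoids the problem at the source. A hedged sketch of that pattern (the helper name is made up and is not called anywhere in this file):

import os
import pickle
import tempfile

def atomic_pickle_dump(obj, path):
    # Write to a temp file in the same directory, then atomically replace the target.
    dir_name = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, path)  # atomic on POSIX and Windows
    except BaseException:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise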
dataloaders/beat_sep_single.py ADDED
@@ -0,0 +1,693 @@
1
+ import os
2
+ import pickle
3
+ import math
4
+ import shutil
5
+ import numpy as np
6
+ import lmdb as lmdb
7
+ import textgrid as tg
8
+ import pandas as pd
9
+ import torch
10
+ import glob
11
+ import json
12
+ from termcolor import colored
13
+ from loguru import logger
14
+ from collections import defaultdict
15
+ from torch.utils.data import Dataset
16
+ import torch.distributed as dist
17
+ #import pyarrow
18
+ import pickle
19
+ import librosa
20
+ import smplx
21
+
22
+ from .build_vocab import Vocab
23
+ from models.utils.wav2vec import Wav2Vec2Model
24
+ from .data_tools import joints_list
25
+ from .utils import rotation_conversions as rc
26
+ from .utils import other_tools
27
+
28
+ import torch.nn.functional as F
29
+
30
+
31
+ class _FallbackLangModel:
32
+ """Minimal vocabulary that grows on demand for demo/test mode."""
33
+
34
+ def __init__(self) -> None:
35
+ self.PAD_token = 0
36
+ self.UNK_token = 1
37
+ self._word_to_idx = {"<pad>": self.PAD_token, "<unk>": self.UNK_token}
38
+ self.word_embedding_weights = np.zeros((2, 300), dtype=np.float32)
39
+
40
+ def get_word_index(self, word: str) -> int:
41
+ if word is None:
42
+ return self.UNK_token
43
+ cleaned = word.strip().lower()
44
+ if not cleaned:
45
+ return self.PAD_token
46
+ return self._word_to_idx["<unk>"]
47
+
48
+
49
+ class CustomDataset(Dataset):
50
+ def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
51
+ self.audio_file_path = args.audio_file_path
52
+ self.textgrid_file_path = args.textgrid_file_path
53
+ self.default_pose_file = "./demo/examples/2_scott_0_1_1.npz"
54
+
55
+ self.args = args
56
+ self.loader_type = loader_type
57
+
58
+ self.rank = 0
59
+ self.ori_stride = self.args.stride
60
+ self.ori_length = self.args.pose_length
61
+ self.alignment = [0,0] # for trinity
62
+
63
+ self.ori_joint_list = joints_list[self.args.ori_joints]
64
+ self.tar_joint_list = joints_list[self.args.tar_joints]
65
+ if 'smplx' in self.args.pose_rep:
66
+ self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3)
67
+ self.joints = len(list(self.tar_joint_list.keys()))
68
+ for joint_name in self.tar_joint_list:
69
+ self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
70
+ else:
71
+ self.joints = len(list(self.ori_joint_list.keys()))+1
72
+ self.joint_mask = np.zeros(self.joints*3)
73
+ for joint_name in self.tar_joint_list:
74
+ if joint_name == "Hips":
75
+ self.joint_mask[3:6] = 1
76
+ else:
77
+ self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
78
+ # select trainable joints
79
+ self.smplx = smplx.create(
80
+ self.args.data_path_1+"smplx_models/",
81
+ model_type='smplx',
82
+ gender='NEUTRAL_2020',
83
+ use_face_contour=False,
84
+ num_betas=300,
85
+ num_expression_coeffs=100,
86
+ ext='npz',
87
+ use_pca=False,
88
+ ).eval()
89
+
90
+ if loader_type == 'test':
91
+ # In demo/test mode, skip dataset CSV and use provided paths
92
+ self.selected_file = pd.DataFrame([{
93
+ 'id': 'demo_0',
94
+ 'audio_path': self.args.audio_file_path or './demo/examples/2_scott_0_1_1.wav',
95
+ 'textgrid_path': self.args.textgrid_file_path or None,
96
+ 'pose_path': self.default_pose_file,
97
+ }])
98
+ else:
99
+ split_rule = pd.read_csv(args.data_path+"train_test_split.csv")
100
+ self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
101
+ if args.additional_data and loader_type == 'train':
102
+ split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
103
+ self.selected_file = pd.concat([self.selected_file, split_b])
104
+ if self.selected_file.empty:
105
+ logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead")
106
+ self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
107
+ self.selected_file = self.selected_file.iloc[0:8]
108
+ self.data_dir = args.data_path
109
+
110
+ if loader_type == "test":
111
+ self.args.multi_length_training = [1.0]
112
+ self.max_length = int(args.pose_length * self.args.multi_length_training[-1])
113
+ self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr)
114
+ if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr:
115
+ self.max_audio_pre_len = self.args.test_length*self.args.audio_sr
116
+
117
+ if args.word_rep is not None:
118
+ vocab_path = f"{args.data_path}weights/vocab.pkl"
119
+ if loader_type == 'test':
120
+ logger.info("Instantiating fallback vocabulary for test loader")
121
+ self.lang_model = _FallbackLangModel()
122
+ elif os.path.exists(vocab_path):
123
+ with open(vocab_path, 'rb') as f:
124
+ self.lang_model = pickle.load(f)
125
+ else:
126
+ logger.warning(f"vocab.pkl not found at {vocab_path}, using fallback vocabulary")
127
+ self.lang_model = _FallbackLangModel()
128
+ else:
129
+ self.lang_model = None
130
+
131
+ preloaded_dir = self.args.tmp_dir+'/' + loader_type + f"/{args.pose_rep}_cache"
132
+
133
+ if self.args.beat_align and loader_type != 'test':
134
+ if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"):
135
+ self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
136
+ self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
137
+ else:
138
+ self.avg_vel = None
139
+
140
+ if build_cache and self.rank == 0:
141
+ self.build_cache(preloaded_dir)
142
+ self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False)
143
+ with self.lmdb_env.begin() as txn:
144
+ self.n_samples = txn.stat()["entries"]
145
+
146
+
147
+
148
+
149
+ def calculate_mean_velocity(self, save_path):
150
+ # Stub for demo mode: write zero velocity to avoid heavy computation
151
+ avg_vel = np.zeros(55)
152
+ np.save(save_path, avg_vel)
153
+
154
+
155
+ def build_cache(self, preloaded_dir):
156
+ logger.info(f"Audio bit rate: {self.args.audio_fps}")
157
+ logger.info("Reading data '{}'...".format(self.data_dir))
158
+ logger.info("Creating the dataset cache...")
159
+ if self.args.new_cache:
160
+ if os.path.exists(preloaded_dir):
161
+ shutil.rmtree(preloaded_dir)
162
+ if os.path.exists(preloaded_dir):
163
+ logger.info("Found the cache {}".format(preloaded_dir))
164
+ elif self.loader_type == "test":
165
+ self.cache_generation(
166
+ preloaded_dir, True,
167
+ 0, 0,
168
+ is_test=True)
169
+ else:
170
+ self.cache_generation(
171
+ preloaded_dir, self.args.disable_filtering,
172
+ self.args.clean_first_seconds, self.args.clean_final_seconds,
173
+ is_test=False)
174
+
175
+ def __len__(self):
176
+ return self.n_samples
177
+
178
+
179
+ def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
180
+ # if "wav2vec2" in self.args.audio_rep:
181
+ # self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h")
182
+ # self.wav2vec_model.feature_extractor._freeze_parameters()
183
+ # self.wav2vec_model = self.wav2vec_model.cuda()
184
+ # self.wav2vec_model.eval()
185
+
186
+ self.n_out_samples = 0
187
+ # create db for samples
188
+ if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir)
189
+ dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 500))# 500G
190
+ n_filtered_out = defaultdict(int)
191
+
192
+
193
+ #f_name = file_name["id"]
194
+ ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
195
+ pose_file = self.default_pose_file
196
+ pose_each_file = []
197
+ trans_each_file = []
198
+ trans_v_each_file = []
199
+ shape_each_file = []
200
+ audio_each_file = []
201
+ facial_each_file = []
202
+ word_each_file = []
203
+ emo_each_file = []
204
+ sem_each_file = []
205
+ vid_each_file = []
206
+ id_pose = "tmp" #1_wayne_0_1_1
207
+
208
+ logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue"))
209
+ if "smplx" in self.args.pose_rep:
210
+ pose_data = np.load(pose_file, allow_pickle=True)
211
+ assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30'
212
+ stride = int(30/self.args.pose_fps)
213
+ pose_each_file = pose_data["poses"][::stride]
214
+ trans_each_file = pose_data["trans"][::stride]
215
+ trans_each_file[:,0] = trans_each_file[:,0] - trans_each_file[0,0]
216
+ trans_each_file[:,2] = trans_each_file[:,2] - trans_each_file[0,2]
217
+ trans_v_each_file = np.zeros_like(trans_each_file)
218
+ trans_v_each_file[1:,0] = trans_each_file[1:,0] - trans_each_file[:-1,0]
219
+ trans_v_each_file[0,0] = trans_v_each_file[1,0]
220
+ trans_v_each_file[1:,2] = trans_each_file[1:,2] - trans_each_file[:-1,2]
221
+ trans_v_each_file[0,2] = trans_v_each_file[1,2]
222
+ trans_v_each_file[:,1] = trans_each_file[:,1]
223
+ shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0)
224
+
225
+ assert self.args.pose_fps == 30, "should 30"
226
+ m_data = np.load(pose_file, allow_pickle=True)
227
+ betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"]
228
+ n, c = poses.shape[0], poses.shape[1]
229
+ betas = betas.reshape(1, 300)
230
+ betas = np.tile(betas, (n, 1))
231
+ betas = torch.from_numpy(betas).float()
232
+ poses = torch.from_numpy(poses.reshape(n, c)).float()
233
+ exps = torch.from_numpy(exps.reshape(n, 100)).float()
234
+ trans = torch.from_numpy(trans.reshape(n, 3)).float()
235
+            max_length = 128  # why is a max_length needed here?
236
+ s, r = n//max_length, n%max_length
237
+ #print(n, s, r)
238
+ all_tensor = []
239
+ for i in range(s):
240
+ with torch.no_grad():
241
+ joints = self.smplx(
242
+ betas=betas[i*max_length:(i+1)*max_length],
243
+ transl=trans[i*max_length:(i+1)*max_length],
244
+ expression=exps[i*max_length:(i+1)*max_length],
245
+ jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69],
246
+ global_orient=poses[i*max_length:(i+1)*max_length,:3],
247
+ body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3],
248
+ left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3],
249
+ right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3],
250
+ return_verts=True,
251
+ return_joints=True,
252
+ leye_pose=poses[i*max_length:(i+1)*max_length, 69:72],
253
+ reye_pose=poses[i*max_length:(i+1)*max_length, 72:75],
254
+ )['joints'][:, (7,8,10,11), :].reshape(max_length, 4, 3).cpu()
255
+ all_tensor.append(joints)
256
+ if r != 0:
257
+ with torch.no_grad():
258
+ joints = self.smplx(
259
+ betas=betas[s*max_length:s*max_length+r],
260
+ transl=trans[s*max_length:s*max_length+r],
261
+ expression=exps[s*max_length:s*max_length+r],
262
+ jaw_pose=poses[s*max_length:s*max_length+r, 66:69],
263
+ global_orient=poses[s*max_length:s*max_length+r,:3],
264
+ body_pose=poses[s*max_length:s*max_length+r,3:21*3+3],
265
+ left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3],
266
+ right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3],
267
+ return_verts=True,
268
+ return_joints=True,
269
+ leye_pose=poses[s*max_length:s*max_length+r, 69:72],
270
+ reye_pose=poses[s*max_length:s*max_length+r, 72:75],
271
+ )['joints'][:, (7,8,10,11), :].reshape(r, 4, 3).cpu()
272
+ all_tensor.append(joints)
273
+ joints = torch.cat(all_tensor, axis=0) # all, 4, 3
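+ # the 4 tracked joints are the left/right ankles and feet (SMPL-X joint indices 7, 8, 10, 11), used below for foot-contact detection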
274
+ # print(joints.shape)
275
+ feetv = torch.zeros(joints.shape[1], joints.shape[0])
276
+ joints = joints.permute(1, 0, 2)
277
+ #print(joints.shape, feetv.shape)
278
+ feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1)
279
+ #print(feetv.shape)
280
+ contacts = (feetv < 0.01).numpy().astype(float)
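+ # a foot joint whose per-frame displacement is below 0.01 (roughly 1 cm at the SMPL-X metre scale, 30 fps) is labelled as in ground contact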
281
+ # print(contacts.shape, contacts)
282
+ contacts = contacts.transpose(1, 0)
283
+ pose_each_file = pose_each_file * self.joint_mask
284
+ pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)]
285
+ pose_each_file = np.concatenate([pose_each_file, contacts], axis=1)
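+ # final per-frame pose feature: masked axis-angle joint rotations concatenated with the 4 foot-contact flags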
286
+ # print(pose_each_file.shape)
287
+
288
+
289
+ if self.args.facial_rep is not None:
290
+ logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
291
+ facial_each_file = pose_data["expressions"][::stride]
292
+ if self.args.facial_norm:
293
+ facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial
294
+
295
+ if self.args.id_rep is not None:
296
+ vid_each_file = np.repeat(np.array(int(999)-1).reshape(1, 1), pose_each_file.shape[0], axis=0)
297
+
298
+ if self.args.audio_rep is not None:
299
+ logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #")
300
+ audio_file = self.audio_file_path
301
+ if not os.path.exists(audio_file):
302
+ logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #")
303
+ self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
304
+
305
+ audio_save_path = audio_file.replace("wave16k", "onset_amplitude").replace(".wav", ".npy")
306
+
307
+ if self.args.audio_rep == "onset+amplitude":
308
+ audio_each_file, sr = librosa.load(audio_file)
309
+ audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr)
310
+ from numpy.lib import stride_tricks
311
+ frame_length = 1024
312
+ # hop_length = 512
313
+ shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length)
314
+ strides = (audio_each_file.strides[-1], audio_each_file.strides[-1])
315
+ rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides)
316
+ amplitude_envelope = np.max(np.abs(rolling_view), axis=1)
317
+ # pad the last frame_length-1 samples
318
+ amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1])
319
+ audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames')
320
+ onset_array = np.zeros(len(audio_each_file), dtype=float)
321
+ onset_array[audio_onset_f] = 1.0
322
+ # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape)
323
+ audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1)
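+ # resulting audio feature has shape (n_audio_samples, 2): amplitude envelope and a binary onset flag, still at the audio sample rate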
324
+
325
+
326
+ elif self.args.audio_rep == "mfcc":
327
+ audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps))
328
+ audio_each_file = audio_each_file.transpose(1, 0)
329
+ # print(audio_each_file.shape, pose_each_file.shape)
330
+ if self.args.audio_norm and self.args.audio_rep == "wave16k":
331
+ audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio
332
+
333
+ time_offset = 0
334
+ if self.args.word_rep is not None:
335
+ logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #")
336
+ word_file = self.textgrid_file_path
337
+ if not os.path.exists(word_file):
338
+ logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #")
339
+ self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
340
+ word_save_path = f"{self.data_dir}{self.args.t_pre_encoder}/{id_pose}.npy"
341
+
342
+ tgrid = tg.TextGrid.fromFile(word_file)
343
+
344
+ for i in range(pose_each_file.shape[0]):
345
+ found_flag = False
346
+ current_time = i/self.args.pose_fps + time_offset
347
+ j_last = 0
348
+ for j, word in enumerate(tgrid[0]):
349
+ word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
350
+ if word_s<=current_time and current_time<=word_e:
351
+ if word_n == " ":
352
+ word_each_file.append(self.lang_model.PAD_token)
353
+ else:
354
+ word_each_file.append(self.lang_model.get_word_index(word_n))
355
+ found_flag = True
356
+ j_last = j
357
+ break
358
+ else: continue
359
+ if not found_flag:
360
+ word_each_file.append(self.lang_model.UNK_token)
361
+ word_each_file = np.array(word_each_file)
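+ # one vocabulary index per pose frame, aligned via the TextGrid intervals: PAD for silence marks, UNK when no interval covers the frame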
362
+
363
+
364
+
365
+ if self.args.emo_rep is not None:
366
+ logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #")
367
+ rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3])
368
+ if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6:
369
+ if start >= 1 and start <= 64:
370
+ score = 0
371
+ elif start >= 65 and start <= 72:
372
+ score = 1
373
+ elif start >= 73 and start <= 80:
374
+ score = 2
375
+ elif start >= 81 and start <= 86:
376
+ score = 3
377
+ elif start >= 87 and start <= 94:
378
+ score = 4
379
+ elif start >= 95 and start <= 102:
380
+ score = 5
381
+ elif start >= 103 and start <= 110:
382
+ score = 6
383
+ elif start >= 111 and start <= 118:
384
+ score = 7
385
+ else: pass
386
+ else:
387
+ # you may denote as unknown in the future
388
+ score = 0
389
+ emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0)
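+ # a single clip-level emotion label (0-7) broadcast to every frame of the sequence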
390
+ #print(emo_each_file)
391
+
392
+ if self.args.sem_rep is not None:
393
+ logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #")
394
+ sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt"
395
+ sem_all = pd.read_csv(sem_file,
396
+ sep='\t',
397
+ names=["name", "start_time", "end_time", "duration", "score", "keywords"])
398
+ # we adopt motion-level semantic score here.
399
+ for i in range(pose_each_file.shape[0]):
400
+ found_flag = False
401
+ for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])):
402
+ current_time = i/self.args.pose_fps + time_offset
403
+ if start<=current_time and current_time<=end:
404
+ sem_each_file.append(score)
405
+ found_flag=True
406
+ break
407
+ else: continue
408
+ if not found_flag: sem_each_file.append(0.)
409
+ sem_each_file = np.array(sem_each_file)
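+ # per-frame semantic relevance score taken from the annotation intervals; frames not covered by any interval get 0.0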
410
+ #print(sem_each_file)
411
+
412
+ filtered_result = self._sample_from_clip(
413
+ dst_lmdb_env,
414
+ audio_each_file, pose_each_file, trans_each_file, trans_v_each_file,shape_each_file, facial_each_file, word_each_file,
415
+ vid_each_file, emo_each_file, sem_each_file,
416
+ disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
417
+ )
418
+ for type in filtered_result.keys():
419
+ n_filtered_out[type] += filtered_result[type]
420
+
421
+
422
+
423
+
424
+ #### ---------for_end------------ ####
425
+ with dst_lmdb_env.begin() as txn:
426
+ logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan"))
427
+ n_total_filtered = 0
428
+ for type, n_filtered in n_filtered_out.items():
429
+ logger.info("{}: {}".format(type, n_filtered))
430
+ n_total_filtered += n_filtered
431
+ logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format(
432
+ n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan"))
433
+ dst_lmdb_env.sync()
434
+ dst_lmdb_env.close()
435
+
436
+ def _sample_from_clip(
437
+ self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, trans_v_each_file,shape_each_file, facial_each_file, word_each_file,
438
+ vid_each_file, emo_each_file, sem_each_file,
439
+ disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
440
+ ):
441
+ """
442
+ for data cleaning, we ignore the data for first and final n s
443
+ for test, we return all data
444
+ """
445
+ # audio_start = int(self.alignment[0] * self.args.audio_fps)
446
+ # pose_start = int(self.alignment[1] * self.args.pose_fps)
447
+ #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}")
448
+ # audio_each_file = audio_each_file[audio_start:]
449
+ # pose_each_file = pose_each_file[pose_start:]
450
+ # trans_each_file =
451
+ #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}")
452
+ #print(pose_each_file.shape)
453
+ round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s
454
+ #print(round_seconds_skeleton)
455
+ if audio_each_file is not None:
456
+ if self.args.audio_rep != "wave16k":
457
+ round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s
458
+ elif self.args.audio_rep == "mfcc":
459
+ round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps
460
+ else:
461
+ round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr
462
+ if facial_each_file is not None:
463
+ round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps
464
+ logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s")
465
+ round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
466
+ max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
467
+ if round_seconds_skeleton != max_round:
468
+ logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s")
469
+ else:
470
+ logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s")
471
+ round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton)
472
+ max_round = max(round_seconds_audio, round_seconds_skeleton)
473
+ if round_seconds_skeleton != max_round:
474
+ logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s")
475
+
476
+ clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s
477
+ clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000]
478
+ clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15]
479
+
480
+
481
+ for ratio in self.args.multi_length_training:
482
+ if is_test:# stride = length for test
483
+ cut_length = clip_e_f_pose - clip_s_f_pose
484
+ self.args.stride = cut_length
485
+ self.max_length = cut_length
486
+ else:
487
+ self.args.stride = int(ratio*self.ori_stride)
488
+ cut_length = int(self.ori_length*ratio)
489
+
490
+ num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1
491
+ logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}")
492
+ logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}")
493
+
494
+ if audio_each_file is not None:
495
+ audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps)
496
+ """
497
+ for audio sr = 16000, fps = 15, pose_length = 34,
498
+ audio short length = 36266.7 -> 36266
499
+ this error is fine.
500
+ """
501
+ logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}")
502
+
503
+ n_filtered_out = defaultdict(int)
504
+ sample_pose_list = []
505
+ sample_audio_list = []
506
+ sample_facial_list = []
507
+ sample_shape_list = []
508
+ sample_word_list = []
509
+ sample_emo_list = []
510
+ sample_sem_list = []
511
+ sample_vid_list = []
512
+ sample_trans_list = []
513
+ sample_trans_v_list = []
514
+
515
+ for i in range(num_subdivision): # cut into around 2s chip, (self npose)
516
+ start_idx = clip_s_f_pose + i * self.args.stride
517
+ fin_idx = start_idx + cut_length
518
+ sample_pose = pose_each_file[start_idx:fin_idx]
519
+
520
+ sample_trans = trans_each_file[start_idx:fin_idx]
521
+ sample_trans_v = trans_v_each_file[start_idx:fin_idx]
522
+ sample_shape = shape_each_file[start_idx:fin_idx]
523
+ # print(sample_pose.shape)
524
+ if self.args.audio_rep is not None:
525
+ audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps)
526
+ audio_end = audio_start + audio_short_length
527
+ sample_audio = audio_each_file[audio_start:audio_end]
528
+ else:
529
+ sample_audio = np.array([-1])
530
+ sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1])
531
+ sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1])
532
+ sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1])
533
+ sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1])
534
+ sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1])
535
+
536
+ if sample_pose.any() != None:
537
+ # filtering motion skeleton data
538
+ sample_pose, filtering_message = MotionPreprocessor(sample_pose).get()
539
+ is_correct_motion = (sample_pose is not None)
540
+ if is_correct_motion or disable_filtering:
541
+ sample_pose_list.append(sample_pose)
542
+ sample_audio_list.append(sample_audio)
543
+ sample_facial_list.append(sample_facial)
544
+ sample_shape_list.append(sample_shape)
545
+ sample_word_list.append(sample_word)
546
+ sample_vid_list.append(sample_vid)
547
+ sample_emo_list.append(sample_emo)
548
+ sample_sem_list.append(sample_sem)
549
+ sample_trans_list.append(sample_trans)
550
+ sample_trans_v_list.append(sample_trans_v)
551
+ else:
552
+ n_filtered_out[filtering_message] += 1
553
+
554
+ if len(sample_pose_list) > 0:
555
+ with dst_lmdb_env.begin(write=True) as txn:
556
+ for pose, audio, facial, shape, word, vid, emo, sem, trans,trans_v in zip(
557
+ sample_pose_list,
558
+ sample_audio_list,
559
+ sample_facial_list,
560
+ sample_shape_list,
561
+ sample_word_list,
562
+ sample_vid_list,
563
+ sample_emo_list,
564
+ sample_sem_list,
565
+ sample_trans_list,
566
+ sample_trans_v_list,):
567
+ k = "{:005}".format(self.n_out_samples).encode("ascii")
568
+ v = [pose, audio, facial, shape, word, emo, sem, vid, trans,trans_v]
569
+ v = pickle.dumps(v,5)
570
+ txn.put(k, v)
571
+ self.n_out_samples += 1
572
+ return n_filtered_out
573
+
574
+ def __getitem__(self, idx):
575
+ with self.lmdb_env.begin(write=False) as txn:
576
+ key = "{:005}".format(idx).encode("ascii")
577
+ sample = txn.get(key)
578
+ sample = pickle.loads(sample)
579
+ tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans,trans_v = sample
580
+ #print(in_shape)
581
+ #vid = torch.from_numpy(vid).int()
582
+ emo = torch.from_numpy(emo).int()
583
+ sem = torch.from_numpy(sem).float()
584
+ in_audio = torch.from_numpy(in_audio).float()
585
+ in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int()
586
+ if self.loader_type == "test":
587
+ tar_pose = torch.from_numpy(tar_pose).float()
588
+ trans = torch.from_numpy(trans).float()
589
+ trans_v = torch.from_numpy(trans_v).float()
590
+ in_facial = torch.from_numpy(in_facial).float()
591
+ vid = torch.from_numpy(vid).float()
592
+ in_shape = torch.from_numpy(in_shape).float()
593
+ else:
594
+ in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
595
+ trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float()
596
+ trans_v = torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float()
597
+ vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float()
598
+ tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float()
599
+ in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float()
600
+ return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans,"trans_v":trans_v}
601
+
602
+
603
+ class MotionPreprocessor:
604
+ def __init__(self, skeletons):
605
+ self.skeletons = skeletons
606
+ #self.mean_pose = mean_pose
607
+ self.filtering_message = "PASS"
608
+
609
+ def get(self):
610
+ assert (self.skeletons is not None)
611
+
612
+ # filtering
613
+ if self.skeletons is not None:
614
+ if self.check_pose_diff():
615
+ self.skeletons = []
616
+ self.filtering_message = "pose"
617
+ # elif self.check_spine_angle():
618
+ # self.skeletons = []
619
+ # self.filtering_message = "spine angle"
620
+ # elif self.check_static_motion():
621
+ # self.skeletons = []
622
+ # self.filtering_message = "motion"
623
+
624
+ # if self.skeletons is not None:
625
+ # self.skeletons = self.skeletons.tolist()
626
+ # for i, frame in enumerate(self.skeletons):
627
+ # assert not np.isnan(self.skeletons[i]).any() # missing joints
628
+
629
+ return self.skeletons, self.filtering_message
630
+
631
+ def check_static_motion(self, verbose=True):
632
+ def get_variance(skeleton, joint_idx):
633
+ wrist_pos = skeleton[:, joint_idx]
634
+ variance = np.sum(np.var(wrist_pos, axis=0))
635
+ return variance
636
+
637
+ left_arm_var = get_variance(self.skeletons, 6)
638
+ right_arm_var = get_variance(self.skeletons, 9)
639
+
640
+ th = 0.0014 # exclude 13110
641
+ # th = 0.002 # exclude 16905
642
+ if left_arm_var < th and right_arm_var < th:
643
+ if verbose:
644
+ print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
645
+ return True
646
+ else:
647
+ if verbose:
648
+ print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
649
+ return False
650
+
651
+
652
+ def check_pose_diff(self, verbose=False):
653
+ # diff = np.abs(self.skeletons - self.mean_pose) # 186*1
654
+ # diff = np.mean(diff)
655
+
656
+ # # th = 0.017
657
+ # th = 0.02 #0.02 # exclude 3594
658
+ # if diff < th:
659
+ # if verbose:
660
+ # print("skip - check_pose_diff {:.5f}".format(diff))
661
+ # return True
662
+ # # th = 3.5 #0.02 # exclude 3594
663
+ # # if 3.5 < diff < 5:
664
+ # # if verbose:
665
+ # # print("skip - check_pose_diff {:.5f}".format(diff))
666
+ # # return True
667
+ # else:
668
+ # if verbose:
669
+ # print("pass - check_pose_diff {:.5f}".format(diff))
670
+ return False
671
+
672
+
673
+ def check_spine_angle(self, verbose=True):
674
+ def angle_between(v1, v2):
675
+ v1_u = v1 / np.linalg.norm(v1)
676
+ v2_u = v2 / np.linalg.norm(v2)
677
+ return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
678
+
679
+ angles = []
680
+ for i in range(self.skeletons.shape[0]):
681
+ spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0]
682
+ angle = angle_between(spine_vec, [0, -1, 0])
683
+ angles.append(angle)
684
+
685
+ if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495
686
+ # if np.rad2deg(max(angles)) > 20: # exclude 8270
687
+ if verbose:
688
+ print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles)))
689
+ return True
690
+ else:
691
+ if verbose:
692
+ print("pass - check_spine_angle {:.5f}".format(max(angles)))
693
+ return False
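A minimal read-back sketch (assuming a hypothetical cache directory; the key format and value layout mirror the __getitem__ and _sample_from_clip code above):

import lmdb
import pickle

env = lmdb.open("datasets/cache/train/smplx_cache", readonly=True, lock=False)  # hypothetical cache path
with env.begin(write=False) as txn:
    raw = txn.get("{:005}".format(0).encode("ascii"))  # same zero-padded key scheme as __getitem__
# stored order matches the list written in _sample_from_clip
pose, audio, facial, shape, word, emo, sem, vid, trans, trans_v = pickle.loads(raw)
print(pose.shape, audio.shape, facial.shape)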
dataloaders/beat_smplx2020.py ADDED
@@ -0,0 +1,763 @@
1
+ import os
2
+ import pickle
3
+ import math
4
+ import shutil
5
+ import numpy as np
6
+ import lmdb as lmdb
7
+ import textgrid as tg
8
+ import pandas as pd
9
+ import torch
10
+ import glob
11
+ import json
12
+ from termcolor import colored
13
+ from loguru import logger
14
+ from collections import defaultdict
15
+ from torch.utils.data import Dataset
16
+ import torch.distributed as dist
17
+ import pyarrow
18
+ import librosa
19
+ import smplx
20
+
21
+ from .build_vocab import Vocab
22
+ from .utils.audio_features import Wav2Vec2Model
23
+ from .data_tools import joints_list
24
+ from .utils import rotation_conversions as rc
25
+ from .utils import other_tools
26
+
27
+ class CustomDataset(Dataset):
28
+ def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
29
+ self.args = args
30
+ self.loader_type = loader_type
31
+
32
+ self.rank = dist.get_rank()
33
+ self.ori_stride = self.args.stride
34
+ self.ori_length = self.args.pose_length
35
+ self.alignment = [0,0] # for trinity
36
+
37
+ self.ori_joint_list = joints_list[self.args.ori_joints]
38
+ self.tar_joint_list = joints_list[self.args.tar_joints]
39
+ if 'smplx' in self.args.pose_rep:
40
+ self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3)
41
+ self.joints = len(list(self.ori_joint_list.keys()))
42
+ for joint_name in self.tar_joint_list:
43
+ self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
44
+ else:
45
+ self.joints = len(list(self.ori_joint_list.keys()))+1
46
+ self.joint_mask = np.zeros(self.joints*3)
47
+ for joint_name in self.tar_joint_list:
48
+ if joint_name == "Hips":
49
+ self.joint_mask[3:6] = 1
50
+ else:
51
+ self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
52
+ # select trainable joints
53
+
54
+ split_rule = pd.read_csv(args.data_path+"train_test_split.csv")
55
+ self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
56
+ if args.additional_data and loader_type == 'train':
57
+ split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
58
+ #self.selected_file = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
59
+ self.selected_file = pd.concat([self.selected_file, split_b])
60
+ if self.selected_file.empty:
61
+ logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead")
62
+ self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
63
+ self.selected_file = self.selected_file.iloc[0:8]
64
+ self.data_dir = args.data_path
65
+
66
+ if loader_type == "test":
67
+ self.args.multi_length_training = [1.0]
68
+ self.max_length = int(args.pose_length * self.args.multi_length_training[-1])
69
+ self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr)
70
+ if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr:
71
+ self.max_audio_pre_len = self.args.test_length*self.args.audio_sr
72
+
73
+ if args.word_rep is not None:
74
+ with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f:
75
+ self.lang_model = pickle.load(f)
76
+
77
+ preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache"
78
+ # if args.pose_norm:
79
+ # # careful for rotation vectors
80
+ # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"):
81
+ # self.calculate_mean_pose()
82
+ # self.mean_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy")
83
+ # self.std_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_std.npy")
84
+ # if args.audio_norm:
85
+ # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/bvh_mean.npy"):
86
+ # self.calculate_mean_audio()
87
+ # self.mean_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_mean.npy")
88
+ # self.std_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_std.npy")
89
+ # if args.facial_norm:
90
+ # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"):
91
+ # self.calculate_mean_face()
92
+ # self.mean_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy")
93
+ # self.std_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy")
94
+ if self.args.beat_align:
95
+ if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"):
96
+ self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
97
+ self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
98
+
99
+ if build_cache and self.rank == 0:
100
+ self.build_cache(preloaded_dir)
101
+ self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False)
102
+ with self.lmdb_env.begin() as txn:
103
+ self.n_samples = txn.stat()["entries"]
104
+
105
+
106
+ def calculate_mean_velocity(self, save_path):
107
+ self.smplx = smplx.create(
108
+ self.args.data_path_1+"smplx_models/",
109
+ model_type='smplx',
110
+ gender='NEUTRAL_2020',
111
+ use_face_contour=False,
112
+ num_betas=300,
113
+ num_expression_coeffs=100,
114
+ ext='npz',
115
+ use_pca=False,
116
+ ).cuda().eval()
117
+ dir_p = self.data_dir + self.args.pose_rep + "/"
118
+ all_list = []
119
+ from tqdm import tqdm
120
+ for tar in tqdm(os.listdir(dir_p)):
121
+ if tar.endswith(".npz"):
122
+ m_data = np.load(dir_p+tar, allow_pickle=True)
123
+ betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"]
124
+ n, c = poses.shape[0], poses.shape[1]
125
+ betas = betas.reshape(1, 300)
126
+ betas = np.tile(betas, (n, 1))
127
+ betas = torch.from_numpy(betas).cuda().float()
128
+ poses = torch.from_numpy(poses.reshape(n, c)).cuda().float()
129
+ exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float()
130
+ trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float()
131
+ max_length = 128
132
+ s, r = n//max_length, n%max_length
133
+ #print(n, s, r)
134
+ all_tensor = []
135
+ for i in range(s):
136
+ with torch.no_grad():
137
+ joints = self.smplx(
138
+ betas=betas[i*max_length:(i+1)*max_length],
139
+ transl=trans[i*max_length:(i+1)*max_length],
140
+ expression=exps[i*max_length:(i+1)*max_length],
141
+ jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69],
142
+ global_orient=poses[i*max_length:(i+1)*max_length,:3],
143
+ body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3],
144
+ left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3],
145
+ right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3],
146
+ return_verts=True,
147
+ return_joints=True,
148
+ leye_pose=poses[i*max_length:(i+1)*max_length, 69:72],
149
+ reye_pose=poses[i*max_length:(i+1)*max_length, 72:75],
150
+ )['joints'][:, :55, :].reshape(max_length, 55*3)
151
+ all_tensor.append(joints)
152
+ if r != 0:
153
+ with torch.no_grad():
154
+ joints = self.smplx(
155
+ betas=betas[s*max_length:s*max_length+r],
156
+ transl=trans[s*max_length:s*max_length+r],
157
+ expression=exps[s*max_length:s*max_length+r],
158
+ jaw_pose=poses[s*max_length:s*max_length+r, 66:69],
159
+ global_orient=poses[s*max_length:s*max_length+r,:3],
160
+ body_pose=poses[s*max_length:s*max_length+r,3:21*3+3],
161
+ left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3],
162
+ right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3],
163
+ return_verts=True,
164
+ return_joints=True,
165
+ leye_pose=poses[s*max_length:s*max_length+r, 69:72],
166
+ reye_pose=poses[s*max_length:s*max_length+r, 72:75],
167
+ )['joints'][:, :55, :].reshape(r, 55*3)
168
+ all_tensor.append(joints)
169
+ joints = torch.cat(all_tensor, axis=0)
170
+ joints = joints.permute(1, 0)
171
+ dt = 1/30
172
+ # first steps is forward diff (t+1 - t) / dt
173
+ init_vel = (joints[:, 1:2] - joints[:, :1]) / dt
174
+ # middle steps are second order (t+1 - t-1) / 2dt
175
+ middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt)
176
+ # last step is backward diff (t - t-1) / dt
177
+ final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
178
+ #print(joints.shape, init_vel.shape, middle_vel.shape, final_vel.shape)
179
+ vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1).permute(1, 0).reshape(n, 55, 3)
180
+ #print(vel_seq.shape)
181
+ #.permute(1, 0).reshape(n, 55, 3)
182
+ vel_seq_np = vel_seq.cpu().numpy()
183
+ vel_joints_np = np.linalg.norm(vel_seq_np, axis=2) # n * 55
184
+ all_list.append(vel_joints_np)
185
+ avg_vel = np.mean(np.concatenate(all_list, axis=0),axis=0) # 55
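+ # avg_vel: mean speed of each of the 55 SMPL-X joints over all frames and files; saved once and reloaded when args.beat_align is set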
186
+ np.save(save_path, avg_vel)
187
+
188
+
189
+ def build_cache(self, preloaded_dir):
190
+ logger.info(f"Audio bit rate: {self.args.audio_fps}")
191
+ logger.info("Reading data '{}'...".format(self.data_dir))
192
+ logger.info("Creating the dataset cache...")
193
+ if self.args.new_cache:
194
+ if os.path.exists(preloaded_dir):
195
+ shutil.rmtree(preloaded_dir)
196
+ if os.path.exists(preloaded_dir):
197
+ logger.info("Found the cache {}".format(preloaded_dir))
198
+ elif self.loader_type == "test":
199
+ self.cache_generation(
200
+ preloaded_dir, True,
201
+ 0, 0,
202
+ is_test=True)
203
+ else:
204
+ self.cache_generation(
205
+ preloaded_dir, self.args.disable_filtering,
206
+ self.args.clean_first_seconds, self.args.clean_final_seconds,
207
+ is_test=False)
208
+
209
+ def __len__(self):
210
+ return self.n_samples
211
+
212
+
213
+ def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
214
+ # if "wav2vec2" in self.args.audio_rep:
215
+ # self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h")
216
+ # self.wav2vec_model.feature_extractor._freeze_parameters()
217
+ # self.wav2vec_model = self.wav2vec_model.cuda()
218
+ # self.wav2vec_model.eval()
219
+
220
+ self.n_out_samples = 0
221
+ # create db for samples
222
+ if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir)
223
+ dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G
224
+ n_filtered_out = defaultdict(int)
225
+
226
+ for index, file_name in self.selected_file.iterrows():
227
+ f_name = file_name["id"]
228
+ ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
229
+ pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext
230
+ pose_each_file = []
231
+ trans_each_file = []
232
+ shape_each_file = []
233
+ audio_each_file = []
234
+ facial_each_file = []
235
+ word_each_file = []
236
+ emo_each_file = []
237
+ sem_each_file = []
238
+ vid_each_file = []
239
+ id_pose = f_name #1_wayne_0_1_1
240
+
241
+ logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue"))
242
+ if "smplx" in self.args.pose_rep:
243
+ pose_data = np.load(pose_file, allow_pickle=True)
244
+ assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30'
245
+ stride = int(30/self.args.pose_fps)
246
+ pose_each_file = pose_data["poses"][::stride] * self.joint_mask
247
+ trans_each_file = pose_data["trans"][::stride]
248
+ shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0)
249
+ if self.args.facial_rep is not None:
250
+ logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
251
+ facial_each_file = pose_data["expressions"][::stride]
252
+ if self.args.facial_norm:
253
+ facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial
254
+
255
+ else:
256
+ assert 120%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120'
257
+ stride = int(120/self.args.pose_fps)
258
+ with open(pose_file, "r") as pose_data:
259
+ for j, line in enumerate(pose_data.readlines()):
260
+ if j < 431: continue
261
+ if j%stride != 0:continue
262
+ data = np.fromstring(line, dtype=float, sep=" ")
263
+ rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints,3), "XYZ")
264
+ rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3)
265
+ rot_data = rot_data.numpy() * self.joint_mask
266
+
267
+ pose_each_file.append(rot_data)
268
+ trans_each_file.append(data[:3])
269
+
270
+ pose_each_file = np.array(pose_each_file)
271
+ # print(pose_each_file.shape)
272
+ trans_each_file = np.array(trans_each_file)
273
+ shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0)
274
+ if self.args.facial_rep is not None:
275
+ logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
276
+ facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json")
277
+ assert 60%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120'
278
+ stride = int(60/self.args.pose_fps)
279
+ if not os.path.exists(facial_file):
280
+ logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #")
281
+ self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
282
+ continue
283
+ with open(facial_file, 'r') as facial_data_file:
284
+ facial_data = json.load(facial_data_file)
285
+ for j, frame_data in enumerate(facial_data['frames']):
286
+ if j%stride != 0:continue
287
+ facial_each_file.append(frame_data['weights'])
288
+ facial_each_file = np.array(facial_each_file)
289
+ if self.args.facial_norm:
290
+ facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial
291
+
292
+ if self.args.id_rep is not None:
293
+ vid_each_file = np.repeat(np.array(int(f_name.split("_")[0])-1).reshape(1, 1), pose_each_file.shape[0], axis=0)
294
+
295
+ if self.args.audio_rep is not None:
296
+ logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #")
297
+ audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
298
+ if not os.path.exists(audio_file):
299
+ logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #")
300
+ self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
301
+ continue
302
+ audio_each_file, sr = librosa.load(audio_file)
303
+ audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr)
304
+ if self.args.audio_rep == "onset+amplitude":
305
+ from numpy.lib import stride_tricks
306
+ frame_length = 1024
307
+ # hop_length = 512
308
+ shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length)
309
+ strides = (audio_each_file.strides[-1], audio_each_file.strides[-1])
310
+ rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides)
311
+ amplitude_envelope = np.max(np.abs(rolling_view), axis=1)
312
+ # pad the last frame_length-1 samples
313
+ amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1])
314
+ audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames')
315
+ onset_array = np.zeros(len(audio_each_file), dtype=float)
316
+ onset_array[audio_onset_f] = 1.0
317
+ # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape)
318
+ audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1)
319
+ elif self.args.audio_rep == "mfcc":
320
+ audio_each_file = librosa.feature.mfcc(y=audio_each_file, sr=self.args.audio_sr, n_mfcc=13, hop_length=int(self.args.audio_sr/self.args.audio_fps))
321
+
322
+ if self.args.audio_norm and self.args.audio_rep == "wave16k":
323
+ audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio
324
+
325
+ time_offset = 0
326
+ if self.args.word_rep is not None:
327
+ logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #")
328
+ word_file = f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid"
329
+ if not os.path.exists(word_file):
330
+ logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #")
331
+ self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
332
+ continue
333
+ tgrid = tg.TextGrid.fromFile(word_file)
334
+ if self.args.t_pre_encoder == "bert":
335
+ from transformers import AutoTokenizer, BertModel
336
+ tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True)
337
+ model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval()
338
+ list_word = []
339
+ all_hidden = []
340
+ max_len = 400
341
+ last = 0
342
+ word_token_mapping = []
343
+ first = True
344
+ for i, word in enumerate(tgrid[0]):
345
+ last = i
346
+ if (i%max_len != 0) or (i==0):
347
+ if word.mark == "":
348
+ list_word.append(".")
349
+ else:
350
+ list_word.append(word.mark)
351
+ else:
352
+ max_counter = max_len
353
+ str_word = ' '.join(map(str, list_word))
354
+ if first:
355
+ global_len = 0
356
+ end = -1
357
+ offset_word = []
358
+ for k, wordvalue in enumerate(list_word):
359
+ start = end+1
360
+ end = start+len(wordvalue)
361
+ offset_word.append((start, end))
362
+ #print(offset_word)
363
+ token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping']
364
+ #print(token_scan)
365
+ for start, end in offset_word:
366
+ sub_mapping = []
367
+ for i, (start_t, end_t) in enumerate(token_scan[1:-1]):
368
+ if int(start) <= int(start_t) and int(end_t) <= int(end):
369
+ #print(i+global_len)
370
+ sub_mapping.append(i+global_len)
371
+ word_token_mapping.append(sub_mapping)
372
+ #print(len(word_token_mapping))
373
+ global_len = word_token_mapping[-1][-1] + 1
374
+ list_word = []
375
+ if word.mark == "":
376
+ list_word.append(".")
377
+ else:
378
+ list_word.append(word.mark)
379
+
380
+ with torch.no_grad():
381
+ inputs = tokenizer(str_word, return_tensors="pt")
382
+ outputs = model(**inputs)
383
+ last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :]
384
+ all_hidden.append(last_hidden_states)
385
+
386
+ #list_word = list_word[:10]
387
+ if list_word == []:
388
+ pass
389
+ else:
390
+ if first:
391
+ global_len = 0
392
+ str_word = ' '.join(map(str, list_word))
393
+ end = -1
394
+ offset_word = []
395
+ for k, wordvalue in enumerate(list_word):
396
+ start = end+1
397
+ end = start+len(wordvalue)
398
+ offset_word.append((start, end))
399
+ #print(offset_word)
400
+ token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping']
401
+ #print(token_scan)
402
+ for start, end in offset_word:
403
+ sub_mapping = []
404
+ for i, (start_t, end_t) in enumerate(token_scan[1:-1]):
405
+ if int(start) <= int(start_t) and int(end_t) <= int(end):
406
+ sub_mapping.append(i+global_len)
407
+ #print(sub_mapping)
408
+ word_token_mapping.append(sub_mapping)
409
+ #print(len(word_token_mapping))
410
+ with torch.no_grad():
411
+ inputs = tokenizer(str_word, return_tensors="pt")
412
+ outputs = model(**inputs)
413
+ last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :]
414
+ all_hidden.append(last_hidden_states)
415
+ last_hidden_states = np.concatenate(all_hidden, axis=0)
416
+
417
+ for i in range(pose_each_file.shape[0]):
418
+ found_flag = False
419
+ current_time = i/self.args.pose_fps + time_offset
420
+ j_last = 0
421
+ for j, word in enumerate(tgrid[0]):
422
+ word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
423
+ if word_s<=current_time and current_time<=word_e:
424
+ if self.args.word_cache and self.args.t_pre_encoder == 'bert':
425
+ mapping_index = word_token_mapping[j]
426
+ #print(mapping_index, word_s, word_e)
427
+ s_t = np.linspace(word_s, word_e, len(mapping_index)+1)
428
+ #print(s_t)
429
+ for tt, t_sep in enumerate(s_t[1:]):
430
+ if current_time <= t_sep:
431
+ #if len(mapping_index) > 1: print(mapping_index[tt])
432
+ word_each_file.append(last_hidden_states[mapping_index[tt]])
433
+ break
434
+ else:
435
+ if word_n == " ":
436
+ word_each_file.append(self.lang_model.PAD_token)
437
+ else:
438
+ word_each_file.append(self.lang_model.get_word_index(word_n))
439
+ found_flag = True
440
+ j_last = j
441
+ break
442
+ else: continue
443
+ if not found_flag:
444
+ if self.args.word_cache and self.args.t_pre_encoder == 'bert':
445
+ word_each_file.append(last_hidden_states[j_last])
446
+ else:
447
+ word_each_file.append(self.lang_model.UNK_token)
448
+ word_each_file = np.array(word_each_file)
449
+ #print(word_each_file.shape)
450
+
451
+ if self.args.emo_rep is not None:
452
+ logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #")
453
+ rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3])
454
+ if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6:
455
+ if start >= 1 and start <= 64:
456
+ score = 0
457
+ elif start >= 65 and start <= 72:
458
+ score = 1
459
+ elif start >= 73 and start <= 80:
460
+ score = 2
461
+ elif start >= 81 and start <= 86:
462
+ score = 3
463
+ elif start >= 87 and start <= 94:
464
+ score = 4
465
+ elif start >= 95 and start <= 102:
466
+ score = 5
467
+ elif start >= 103 and start <= 110:
468
+ score = 6
469
+ elif start >= 111 and start <= 118:
470
+ score = 7
471
+ else: pass
472
+ else:
473
+ # you may denote as unknown in the future
474
+ score = 0
475
+ emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0)
476
+ #print(emo_each_file)
477
+
478
+ if self.args.sem_rep is not None:
479
+ logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #")
480
+ sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt"
481
+ sem_all = pd.read_csv(sem_file,
482
+ sep='\t',
483
+ names=["name", "start_time", "end_time", "duration", "score", "keywords"])
484
+ # we adopt motion-level semantic score here.
485
+ for i in range(pose_each_file.shape[0]):
486
+ found_flag = False
487
+ for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])):
488
+ current_time = i/self.args.pose_fps + time_offset
489
+ if start<=current_time and current_time<=end:
490
+ sem_each_file.append(score)
491
+ found_flag=True
492
+ break
493
+ else: continue
494
+ if not found_flag: sem_each_file.append(0.)
495
+ sem_each_file = np.array(sem_each_file)
496
+ #print(sem_each_file)
497
+
498
+ filtered_result = self._sample_from_clip(
499
+ dst_lmdb_env,
500
+ audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file,
501
+ vid_each_file, emo_each_file, sem_each_file,
502
+ disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
503
+ )
504
+ for type in filtered_result.keys():
505
+ n_filtered_out[type] += filtered_result[type]
506
+
507
+ with dst_lmdb_env.begin() as txn:
508
+ logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan"))
509
+ n_total_filtered = 0
510
+ for type, n_filtered in n_filtered_out.items():
511
+ logger.info("{}: {}".format(type, n_filtered))
512
+ n_total_filtered += n_filtered
513
+ logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format(
514
+ n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan"))
515
+ dst_lmdb_env.sync()
516
+ dst_lmdb_env.close()
517
+
518
+ def _sample_from_clip(
519
+ self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file,
520
+ vid_each_file, emo_each_file, sem_each_file,
521
+ disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
522
+ ):
523
+ """
524
+ for data cleaning, we ignore the data for first and final n s
525
+ for test, we return all data
526
+ """
527
+ # audio_start = int(self.alignment[0] * self.args.audio_fps)
528
+ # pose_start = int(self.alignment[1] * self.args.pose_fps)
529
+ #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}")
530
+ # audio_each_file = audio_each_file[audio_start:]
531
+ # pose_each_file = pose_each_file[pose_start:]
532
+ # trans_each_file =
533
+ #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}")
534
+ #print(pose_each_file.shape)
535
+ round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s
536
+ #print(round_seconds_skeleton)
537
+ if audio_each_file != []:
538
+ round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s
539
+ if facial_each_file != []:
540
+ round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps
541
+ logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s")
542
+ round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
543
+ max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
544
+ if round_seconds_skeleton != max_round:
545
+ logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s")
546
+ else:
547
+ logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s")
548
+ round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton)
549
+ max_round = max(round_seconds_audio, round_seconds_skeleton)
550
+ if round_seconds_skeleton != max_round:
551
+ logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s")
552
+
553
+ clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s
554
+ clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000]
555
+ clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15]
556
+
557
+
558
+ for ratio in self.args.multi_length_training:
559
+ if is_test:# stride = length for test
560
+ cut_length = clip_e_f_pose - clip_s_f_pose
561
+ self.args.stride = cut_length
562
+ self.max_length = cut_length
563
+ else:
564
+ self.args.stride = int(ratio*self.ori_stride)
565
+ cut_length = int(self.ori_length*ratio)
566
+
567
+ num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1
568
+ logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}")
569
+ logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}")
570
+
571
+ if audio_each_file != []:
572
+ audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps)
573
+ """
574
+ for audio sr = 16000, fps = 15, pose_length = 34,
575
+ audio short length = 36266.7 -> 36266
576
+ this error is fine.
577
+ """
578
+ logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}")
579
+
580
+ n_filtered_out = defaultdict(int)
581
+ sample_pose_list = []
582
+ sample_audio_list = []
583
+ sample_facial_list = []
584
+ sample_shape_list = []
585
+ sample_word_list = []
586
+ sample_emo_list = []
587
+ sample_sem_list = []
588
+ sample_vid_list = []
589
+ sample_trans_list = []
590
+
591
+ for i in range(num_subdivision): # cut into around 2s chip, (self npose)
592
+ start_idx = clip_s_f_pose + i * self.args.stride
593
+ fin_idx = start_idx + cut_length
594
+ sample_pose = pose_each_file[start_idx:fin_idx]
595
+ sample_trans = trans_each_file[start_idx:fin_idx]
596
+ sample_shape = shape_each_file[start_idx:fin_idx]
597
+ # print(sample_pose.shape)
598
+ if self.args.audio_rep is not None:
599
+ audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps)
600
+ audio_end = audio_start + audio_short_length
601
+ sample_audio = audio_each_file[audio_start:audio_end]
602
+ else:
603
+ sample_audio = np.array([-1])
604
+ sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1])
605
+ sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1])
606
+ sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1])
607
+ sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1])
608
+ sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1])
609
+
610
+ if sample_pose.any() != None:
611
+ # filtering motion skeleton data
612
+ sample_pose, filtering_message = MotionPreprocessor(sample_pose).get()
613
+ is_correct_motion = (sample_pose != [])
614
+ if is_correct_motion or disable_filtering:
615
+ sample_pose_list.append(sample_pose)
616
+ sample_audio_list.append(sample_audio)
617
+ sample_facial_list.append(sample_facial)
618
+ sample_shape_list.append(sample_shape)
619
+ sample_word_list.append(sample_word)
620
+ sample_vid_list.append(sample_vid)
621
+ sample_emo_list.append(sample_emo)
622
+ sample_sem_list.append(sample_sem)
623
+ sample_trans_list.append(sample_trans)
624
+ else:
625
+ n_filtered_out[filtering_message] += 1
626
+
627
+ if len(sample_pose_list) > 0:
628
+ with dst_lmdb_env.begin(write=True) as txn:
629
+ for pose, audio, facial, shape, word, vid, emo, sem, trans in zip(
630
+ sample_pose_list,
631
+ sample_audio_list,
632
+ sample_facial_list,
633
+ sample_shape_list,
634
+ sample_word_list,
635
+ sample_vid_list,
636
+ sample_emo_list,
637
+ sample_sem_list,
638
+ sample_trans_list,):
639
+ k = "{:005}".format(self.n_out_samples).encode("ascii")
640
+ v = [pose, audio, facial, shape, word, emo, sem, vid, trans]
641
+ v = pyarrow.serialize(v).to_buffer()
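+ # NOTE: pyarrow.serialize/deserialize is deprecated in recent pyarrow releases; the newer cache writers in this upload use pickle.dumps/pickle.loads instead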
642
+ txn.put(k, v)
643
+ self.n_out_samples += 1
644
+ return n_filtered_out
645
+
646
+ def __getitem__(self, idx):
647
+ with self.lmdb_env.begin(write=False) as txn:
648
+ key = "{:005}".format(idx).encode("ascii")
649
+ sample = txn.get(key)
650
+ sample = pyarrow.deserialize(sample)
651
+ tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample
652
+ #print(in_shape)
653
+ #vid = torch.from_numpy(vid).int()
654
+ emo = torch.from_numpy(emo).int()
655
+ sem = torch.from_numpy(sem).float()
656
+ in_audio = torch.from_numpy(in_audio).float()
657
+ in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int()
658
+ if self.loader_type == "test":
659
+ tar_pose = torch.from_numpy(tar_pose).float()
660
+ trans = torch.from_numpy(trans).float()
661
+ in_facial = torch.from_numpy(in_facial).float()
662
+ vid = torch.from_numpy(vid).float()
663
+ in_shape = torch.from_numpy(in_shape).float()
664
+ else:
665
+ in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
666
+ trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float()
667
+ vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float()
668
+ tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float()
669
+ in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float()
670
+ return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans}
671
+
672
+
673
+ class MotionPreprocessor:
674
+ def __init__(self, skeletons):
675
+ self.skeletons = skeletons
676
+ #self.mean_pose = mean_pose
677
+ self.filtering_message = "PASS"
678
+
679
+ def get(self):
680
+ assert (self.skeletons is not None)
681
+
682
+ # filtering
683
+ if self.skeletons != []:
684
+ if self.check_pose_diff():
685
+ self.skeletons = []
686
+ self.filtering_message = "pose"
687
+ # elif self.check_spine_angle():
688
+ # self.skeletons = []
689
+ # self.filtering_message = "spine angle"
690
+ # elif self.check_static_motion():
691
+ # self.skeletons = []
692
+ # self.filtering_message = "motion"
693
+
694
+ # if self.skeletons != []:
695
+ # self.skeletons = self.skeletons.tolist()
696
+ # for i, frame in enumerate(self.skeletons):
697
+ # assert not np.isnan(self.skeletons[i]).any() # missing joints
698
+
699
+ return self.skeletons, self.filtering_message
700
+
701
+ def check_static_motion(self, verbose=True):
702
+ def get_variance(skeleton, joint_idx):
703
+ wrist_pos = skeleton[:, joint_idx]
704
+ variance = np.sum(np.var(wrist_pos, axis=0))
705
+ return variance
706
+
707
+ left_arm_var = get_variance(self.skeletons, 6)
708
+ right_arm_var = get_variance(self.skeletons, 9)
709
+
710
+ th = 0.0014 # exclude 13110
711
+ # th = 0.002 # exclude 16905
712
+ if left_arm_var < th and right_arm_var < th:
713
+ if verbose:
714
+ print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
715
+ return True
716
+ else:
717
+ if verbose:
718
+ print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
719
+ return False
720
+
721
+
722
+ def check_pose_diff(self, verbose=False):
723
+ # diff = np.abs(self.skeletons - self.mean_pose) # 186*1
724
+ # diff = np.mean(diff)
725
+
726
+ # # th = 0.017
727
+ # th = 0.02 #0.02 # exclude 3594
728
+ # if diff < th:
729
+ # if verbose:
730
+ # print("skip - check_pose_diff {:.5f}".format(diff))
731
+ # return True
732
+ # # th = 3.5 #0.02 # exclude 3594
733
+ # # if 3.5 < diff < 5:
734
+ # # if verbose:
735
+ # # print("skip - check_pose_diff {:.5f}".format(diff))
736
+ # # return True
737
+ # else:
738
+ # if verbose:
739
+ # print("pass - check_pose_diff {:.5f}".format(diff))
740
+ return False
741
+
742
+
743
+ def check_spine_angle(self, verbose=True):
744
+ def angle_between(v1, v2):
745
+ v1_u = v1 / np.linalg.norm(v1)
746
+ v2_u = v2 / np.linalg.norm(v2)
747
+ return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
748
+
749
+ angles = []
750
+ for i in range(self.skeletons.shape[0]):
751
+ spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0]
752
+ angle = angle_between(spine_vec, [0, -1, 0])
753
+ angles.append(angle)
754
+
755
+ if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495
756
+ # if np.rad2deg(max(angles)) > 20: # exclude 8270
757
+ if verbose:
758
+ print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles)))
759
+ return True
760
+ else:
761
+ if verbose:
762
+ print("pass - check_spine_angle {:.5f}".format(max(angles)))
763
+ return False
dataloaders/build_vocab.py ADDED
@@ -0,0 +1,199 @@
1
+ import numpy as np
2
+ import glob
3
+ import os
4
+ import pickle
5
+ import lmdb
6
+ #import pyarrow
7
+ import fasttext
8
+ from loguru import logger
9
+ from scipy import linalg
10
+
11
+
12
+ class Vocab:
13
+ PAD_token = 0
14
+ SOS_token = 1
15
+ EOS_token = 2
16
+ UNK_token = 3
17
+
18
+ def __init__(self, name, insert_default_tokens=True):
19
+ self.name = name
20
+ self.trimmed = False
21
+ self.word_embedding_weights = None
22
+ self.reset_dictionary(insert_default_tokens)
23
+
24
+ def reset_dictionary(self, insert_default_tokens=True):
25
+ self.word2index = {}
26
+ self.word2count = {}
27
+ if insert_default_tokens:
28
+ self.index2word = {self.PAD_token: "<PAD>", self.SOS_token: "<SOS>",
29
+ self.EOS_token: "<EOS>", self.UNK_token: "<UNK>"}
30
+ else:
31
+ self.index2word = {self.UNK_token: "<UNK>"}
32
+ self.n_words = len(self.index2word) # count default tokens
33
+
34
+ def index_word(self, word):
35
+ if word not in self.word2index:
36
+ self.word2index[word] = self.n_words
37
+ self.word2count[word] = 1
38
+ self.index2word[self.n_words] = word
39
+ self.n_words += 1
40
+ else:
41
+ self.word2count[word] += 1
42
+
43
+ def add_vocab(self, other_vocab):
44
+ for word, _ in other_vocab.word2count.items():
45
+ self.index_word(word)
46
+
47
+ # remove words below a certain count threshold
48
+ def trim(self, min_count):
49
+ if self.trimmed:
50
+ return
51
+ self.trimmed = True
52
+
53
+ keep_words = []
54
+
55
+ for k, v in self.word2count.items():
56
+ if v >= min_count:
57
+ keep_words.append(k)
58
+
59
+ print(' word trimming, kept %s / %s = %.4f' % (
60
+ len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
61
+ ))
62
+
63
+ # reinitialize dictionary
64
+ self.reset_dictionary()
65
+ for word in keep_words:
66
+ self.index_word(word)
67
+
68
+ def get_word_index(self, word):
69
+ if word in self.word2index:
70
+ return self.word2index[word]
71
+ else:
72
+ return self.UNK_token
73
+
74
+ def load_word_vectors(self, pretrained_path, embedding_dim=300):
75
+ print(" loading word vectors from '{}'...".format(pretrained_path))
76
+
77
+ # initialize embeddings to random values for special words
78
+ init_sd = 1 / np.sqrt(embedding_dim)
79
+ weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim])
80
+ weights = weights.astype(np.float32)
81
+
82
+ # read word vectors
83
+ word_model = fasttext.load_model(pretrained_path)
84
+ for word, id in self.word2index.items():
85
+ vec = word_model.get_word_vector(word)
86
+ weights[id] = vec
87
+ self.word_embedding_weights = weights
88
+
89
+ def __get_embedding_weight(self, pretrained_path, embedding_dim=300):
90
+ """ function modified from http://ronny.rest/blog/post_2017_08_04_glove/ """
91
+ print("Loading word embedding '{}'...".format(pretrained_path))
92
+ cache_path = pretrained_path
93
+ weights = None
94
+
95
+ # use cached file if it exists
96
+ if os.path.exists(cache_path): #
97
+ with open(cache_path, 'rb') as f:
98
+ print(' using cached result from {}'.format(cache_path))
99
+ weights = pickle.load(f)
100
+ if weights.shape != (self.n_words, embedding_dim):
101
+ logger.warning(' failed to load word embedding weights. reinitializing...')
102
+ weights = None
103
+
104
+ if weights is None:
105
+ # initialize embeddings to random values for special and OOV words
106
+ init_sd = 1 / np.sqrt(embedding_dim)
107
+ weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim])
108
+ weights = weights.astype(np.float32)
109
+
110
+ with open(pretrained_path, encoding="utf-8", mode="r") as textFile:
111
+ num_embedded_words = 0
112
+ for line_raw in textFile:
113
+ # extract the word, and embeddings vector
114
+ line = line_raw.split()
115
+ try:
116
+ word, vector = (line[0], np.array(line[1:], dtype=np.float32))
117
+ # if word == 'love': # debugging
118
+ # print(word, vector)
119
+
120
+ # if it is in our vocab, then update the corresponding weights
121
+ id = self.word2index.get(word, None)
122
+ if id is not None:
123
+ weights[id] = vector
124
+ num_embedded_words += 1
125
+ except ValueError:
126
+ print(' parsing error at {}...'.format(line_raw[:50]))
127
+ continue
128
+ print(' {} / {} word vectors are found in the embedding'.format(num_embedded_words, len(self.word2index)))
129
+
130
+ with open(cache_path, 'wb') as f:
131
+ pickle.dump(weights, f)
132
+ return weights
133
+
134
+
135
+ def build_vocab(name, data_path, cache_path, word_vec_path=None, feat_dim=None):
136
+ print(' building a language model...')
137
+ #if not os.path.exists(cache_path):
138
+ lang_model = Vocab(name)
139
+ print(' indexing words from {}'.format(data_path))
140
+ index_words_from_textgrid(lang_model, data_path)
141
+
142
+ if word_vec_path is not None:
143
+ lang_model.load_word_vectors(word_vec_path, feat_dim)
144
+ else:
145
+ print(' loaded from {}'.format(cache_path))
146
+ with open(cache_path, 'rb') as f:
147
+ lang_model = pickle.load(f)
148
+ if word_vec_path is None:
149
+ lang_model.word_embedding_weights = None
150
+ elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words:
151
+ logger.warning(' failed to load word embedding weights. check this')
152
+ assert False
153
+
154
+ with open(cache_path, 'wb') as f:
155
+ pickle.dump(lang_model, f)
156
+
157
+
158
+ return lang_model
159
+
160
+
161
+ def index_words(lang_model, data_path):
162
+ #index words form text
163
+ with open(data_path, "r") as f:
164
+ for line in f.readlines():
165
+ line = line.replace(",", " ")
166
+ line = line.replace(".", " ")
167
+ line = line.replace("?", " ")
168
+ line = line.replace("!", " ")
169
+ for word in line.split():
170
+ lang_model.index_word(word)
171
+ print(' indexed %d words' % lang_model.n_words)
172
+
173
+ def index_words_from_textgrid(lang_model, data_path):
174
+ import textgrid as tg
175
+ from tqdm import tqdm
176
+ #trainvaltest=os.listdir(data_path)
177
+ # for loadtype in trainvaltest:
178
+ # if "." in loadtype: continue #ignore .ipynb_checkpoints
179
+ texts = os.listdir(data_path+"/textgrid/")
180
+ #print(texts)
181
+ for textfile in tqdm(texts):
182
+ tgrid = tg.TextGrid.fromFile(data_path+"/textgrid/"+textfile)
183
+ for word in tgrid[0]:
184
+ word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
185
+ word_n = word_n.replace(",", " ")
186
+ word_n = word_n.replace(".", " ")
187
+ word_n = word_n.replace("?", " ")
188
+ word_n = word_n.replace("!", " ")
189
+ #print(word_n)
190
+ lang_model.index_word(word_n)
191
+ print(' indexed %d words' % lang_model.n_words)
192
+ print(lang_model.word2index, lang_model.word2count)
193
+
194
+ if __name__ == "__main__":
195
+ # 11195 for all, 5793 for 4 speakers
196
+ # build_vocab("beat_english_15_141", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/vocab.pkl", "/home/ma-user/work/datasets/cc.en.300.bin", 300)
197
+ build_vocab("beat_chinese_v1.0.0", "/data/datasets/beat_chinese_v1.0.0/", "/data/datasets/beat_chinese_v1.0.0/weights/vocab.pkl", "/home/ma-user/work/cc.zh.300.bin", 300)
198
+
199
+
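
A short usage sketch for the Vocab class added above may be useful; it exercises only the methods defined in this file, and the import path assumes the repository root is on PYTHONPATH:

```python
from dataloaders.build_vocab import Vocab

vocab = Vocab("toy")                      # starts with <PAD>, <SOS>, <EOS>, <UNK>
for w in ["hello", "world", "hello", "gesture"]:
    vocab.index_word(w)

print(vocab.n_words)                      # 7 = 4 special tokens + 3 distinct words
print(vocab.get_word_index("hello"))      # 4, the first non-special index
print(vocab.get_word_index("missing"))    # 3, falls back to UNK_token

vocab.trim(min_count=2)                   # drops "world" and "gesture"
print(vocab.get_word_index("world"))      # 3 again: trimmed words map to <UNK>
```
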
dataloaders/data_tools.py ADDED
@@ -0,0 +1,1756 @@
1
+ import numpy as np
2
+ import glob
3
+ import os
4
+ import pickle
5
+ import lmdb
6
+ #import pyarrow
7
+ import fasttext
8
+ from loguru import logger
9
+ from scipy import linalg
+ import pandas as pd  # used by FIDCalculator below (cal_vol, _joint_selector)
10
+ from .pymo.parsers import BVHParser
11
+ from .pymo.viz_tools import *
12
+ from .pymo.preprocessing import *
13
+
14
+
15
+
16
+
17
+ # pose version fpsxx_trinity/japanese_joints(_xxx)
18
+ joints_list = {
19
+ "trinity_joints":{
20
+ 'Hips': [6,6],
21
+ 'Spine': [3,9],
22
+ 'Spine1': [3,12],
23
+ 'Spine2': [3,15],
24
+ 'Spine3': [3,18],
25
+ 'Neck': [3,21],
26
+ 'Neck1': [3,24],
27
+ 'Head': [3,27],
28
+ 'RShoulder': [3,30],
29
+ 'RArm': [3,33],
30
+ 'RArm1': [3,36],
31
+ 'RHand': [3,39],
32
+ 'RHandT1': [3,42],
33
+ 'RHandT2': [3,45],
34
+ 'RHandT3': [3,48],
35
+ 'RHandI1': [3,51],
36
+ 'RHandI2': [3,54],
37
+ 'RHandI3': [3,57],
38
+ 'RHandM1': [3,60],
39
+ 'RHandM2': [3,63],
40
+ 'RHandM3': [3,66],
41
+ 'RHandR1': [3,69],
42
+ 'RHandR2': [3,72],
43
+ 'RHandR3': [3,75],
44
+ 'RHandP1': [3,78],
45
+ 'RHandP2': [3,81],
46
+ 'RHandP3': [3,84],
47
+ 'LShoulder': [3,87],
48
+ 'LArm': [3,90],
49
+ 'LArm1': [3,93],
50
+ 'LHand': [3,96],
51
+ 'LHandT1': [3,99],
52
+ 'LHandT2': [3,102],
53
+ 'LHandT3': [3,105],
54
+ 'LHandI1': [3,108],
55
+ 'LHandI2': [3,111],
56
+ 'LHandI3': [3,114],
57
+ 'LHandM1': [3,117],
58
+ 'LHandM2': [3,120],
59
+ 'LHandM3': [3,123],
60
+ 'LHandR1': [3,126],
61
+ 'LHandR2': [3,129],
62
+ 'LHandR3': [3,132],
63
+ 'LHandP1': [3,135],
64
+ 'LHandP2': [3,138],
65
+ 'LHandP3': [3,141],
66
+ 'RUpLeg': [3,144],
67
+ 'RLeg': [3,147],
68
+ 'RFoot': [3,150],
69
+ 'RFootF': [3,153],
70
+ 'RToeBase': [3,156],
71
+ 'LUpLeg': [3,159],
72
+ 'LLeg': [3,162],
73
+ 'LFoot': [3,165],
74
+ 'LFootF': [3,168],
75
+ 'LToeBase': [3,171],},
76
+ "trinity_joints_123":{
77
+ 'Spine': 3 ,
78
+ 'Neck': 3 ,
79
+ 'Neck1': 3 ,
80
+ 'RShoulder': 3 ,
81
+ 'RArm': 3 ,
82
+ 'RArm1': 3 ,
83
+ 'RHand': 3 ,
84
+ 'RHandT1': 3 ,
85
+ 'RHandT2': 3 ,
86
+ 'RHandT3': 3 ,
87
+ 'RHandI1': 3 ,
88
+ 'RHandI2': 3 ,
89
+ 'RHandI3': 3 ,
90
+ 'RHandM1': 3 ,
91
+ 'RHandM2': 3 ,
92
+ 'RHandM3': 3 ,
93
+ 'RHandR1': 3 ,
94
+ 'RHandR2': 3 ,
95
+ 'RHandR3': 3 ,
96
+ 'RHandP1': 3 ,
97
+ 'RHandP2': 3 ,
98
+ 'RHandP3': 3 ,
99
+ 'LShoulder': 3 ,
100
+ 'LArm': 3 ,
101
+ 'LArm1': 3 ,
102
+ 'LHand': 3 ,
103
+ 'LHandT1': 3 ,
104
+ 'LHandT2': 3 ,
105
+ 'LHandT3': 3 ,
106
+ 'LHandI1': 3 ,
107
+ 'LHandI2': 3 ,
108
+ 'LHandI3': 3 ,
109
+ 'LHandM1': 3 ,
110
+ 'LHandM2': 3 ,
111
+ 'LHandM3': 3 ,
112
+ 'LHandR1': 3 ,
113
+ 'LHandR2': 3 ,
114
+ 'LHandR3': 3 ,
115
+ 'LHandP1': 3 ,
116
+ 'LHandP2': 3 ,
117
+ 'LHandP3': 3 ,},
118
+ "trinity_joints_168":{
119
+ 'Hips': 3 ,
120
+ 'Spine': 3 ,
121
+ 'Spine1': 3 ,
122
+ 'Spine2': 3 ,
123
+ 'Spine3': 3 ,
124
+ 'Neck': 3 ,
125
+ 'Neck1': 3 ,
126
+ 'Head': 3 ,
127
+ 'RShoulder': 3 ,
128
+ 'RArm': 3 ,
129
+ 'RArm1': 3 ,
130
+ 'RHand': 3 ,
131
+ 'RHandT1': 3 ,
132
+ 'RHandT2': 3 ,
133
+ 'RHandT3': 3 ,
134
+ 'RHandI1': 3 ,
135
+ 'RHandI2': 3 ,
136
+ 'RHandI3': 3 ,
137
+ 'RHandM1': 3 ,
138
+ 'RHandM2': 3 ,
139
+ 'RHandM3': 3 ,
140
+ 'RHandR1': 3 ,
141
+ 'RHandR2': 3 ,
142
+ 'RHandR3': 3 ,
143
+ 'RHandP1': 3 ,
144
+ 'RHandP2': 3 ,
145
+ 'RHandP3': 3 ,
146
+ 'LShoulder': 3 ,
147
+ 'LArm': 3 ,
148
+ 'LArm1': 3 ,
149
+ 'LHand': 3 ,
150
+ 'LHandT1': 3 ,
151
+ 'LHandT2': 3 ,
152
+ 'LHandT3': 3 ,
153
+ 'LHandI1': 3 ,
154
+ 'LHandI2': 3 ,
155
+ 'LHandI3': 3 ,
156
+ 'LHandM1': 3 ,
157
+ 'LHandM2': 3 ,
158
+ 'LHandM3': 3 ,
159
+ 'LHandR1': 3 ,
160
+ 'LHandR2': 3 ,
161
+ 'LHandR3': 3 ,
162
+ 'LHandP1': 3 ,
163
+ 'LHandP2': 3 ,
164
+ 'LHandP3': 3 ,
165
+ 'RUpLeg': 3 ,
166
+ 'RLeg': 3 ,
167
+ 'RFoot': 3 ,
168
+ 'RFootF': 3 ,
169
+ 'RToeBase': 3 ,
170
+ 'LUpLeg': 3 ,
171
+ 'LLeg': 3 ,
172
+ 'LFoot': 3 ,
173
+ 'LFootF': 3 ,
174
+ 'LToeBase': 3 ,},
175
+ "trinity_joints_138":{
176
+ "Hips": 3 ,
177
+ 'Spine': 3 ,
178
+ 'Spine1': 3 ,
179
+ 'Spine2': 3 ,
180
+ 'Spine3': 3 ,
181
+ 'Neck': 3 ,
182
+ 'Neck1': 3 ,
183
+ 'Head': 3 ,
184
+ 'RShoulder': 3 ,
185
+ 'RArm': 3 ,
186
+ 'RArm1': 3 ,
187
+ 'RHand': 3 ,
188
+ 'RHandT1': 3 ,
189
+ 'RHandT2': 3 ,
190
+ 'RHandT3': 3 ,
191
+ 'RHandI1': 3 ,
192
+ 'RHandI2': 3 ,
193
+ 'RHandI3': 3 ,
194
+ 'RHandM1': 3 ,
195
+ 'RHandM2': 3 ,
196
+ 'RHandM3': 3 ,
197
+ 'RHandR1': 3 ,
198
+ 'RHandR2': 3 ,
199
+ 'RHandR3': 3 ,
200
+ 'RHandP1': 3 ,
201
+ 'RHandP2': 3 ,
202
+ 'RHandP3': 3 ,
203
+ 'LShoulder': 3 ,
204
+ 'LArm': 3 ,
205
+ 'LArm1': 3 ,
206
+ 'LHand': 3 ,
207
+ 'LHandT1': 3 ,
208
+ 'LHandT2': 3 ,
209
+ 'LHandT3': 3 ,
210
+ 'LHandI1': 3 ,
211
+ 'LHandI2': 3 ,
212
+ 'LHandI3': 3 ,
213
+ 'LHandM1': 3 ,
214
+ 'LHandM2': 3 ,
215
+ 'LHandM3': 3 ,
216
+ 'LHandR1': 3 ,
217
+ 'LHandR2': 3 ,
218
+ 'LHandR3': 3 ,
219
+ 'LHandP1': 3 ,
220
+ 'LHandP2': 3 ,
221
+ 'LHandP3': 3 ,},
222
+ "beat_smplx_joints": {
223
+ 'pelvis': [3,3],
224
+ 'left_hip': [3,6],
225
+ 'right_hip': [3,9],
226
+ 'spine1': [3,12],
227
+ 'left_knee': [3,15],
228
+ 'right_knee': [3,18],
229
+ 'spine2': [3,21],
230
+ 'left_ankle': [3,24],
231
+ 'right_ankle': [3,27],
232
+
233
+ 'spine3': [3,30],
234
+ 'left_foot': [3,33],
235
+ 'right_foot': [3,36],
236
+ 'neck': [3,39],
237
+ 'left_collar': [3,42],
238
+ 'right_collar': [3,45],
239
+ 'head': [3,48],
240
+ 'left_shoulder': [3,51],
241
+
242
+ 'right_shoulder': [3,54],
243
+ 'left_elbow': [3,57],
244
+ 'right_elbow': [3,60],
245
+ 'left_wrist': [3,63],
246
+ 'right_wrist': [3,66],
247
+
248
+ 'jaw': [3,69],
249
+ 'left_eye_smplhf': [3,72],
250
+ 'right_eye_smplhf': [3,75],
251
+ 'left_index1': [3,78],
252
+ 'left_index2': [3,81],
253
+
254
+ 'left_index3': [3,84],
255
+ 'left_middle1': [3,87],
256
+ 'left_middle2': [3,90],
257
+ 'left_middle3': [3,93],
258
+ 'left_pinky1': [3,96],
259
+
260
+ 'left_pinky2': [3,99],
261
+ 'left_pinky3': [3,102],
262
+ 'left_ring1': [3,105],
263
+ 'left_ring2': [3,108],
264
+
265
+ 'left_ring3': [3,111],
266
+ 'left_thumb1': [3,114],
267
+ 'left_thumb2': [3,117],
268
+ 'left_thumb3': [3,120],
269
+ 'right_index1': [3,123],
270
+ 'right_index2': [3,126],
271
+ 'right_index3': [3,129],
272
+ 'right_middle1': [3,132],
273
+
274
+ 'right_middle2': [3,135],
275
+ 'right_middle3': [3,138],
276
+ 'right_pinky1': [3,141],
277
+ 'right_pinky2': [3,144],
278
+ 'right_pinky3': [3,147],
279
+
280
+ 'right_ring1': [3,150],
281
+ 'right_ring2': [3,153],
282
+ 'right_ring3': [3,156],
283
+ 'right_thumb1': [3,159],
284
+ 'right_thumb2': [3,162],
285
+ 'right_thumb3': [3,165],
286
+
287
+ # 'nose': [3,168],
288
+ # 'right_eye': [3,171],
289
+ # 'left_eye': [3,174],
290
+ # 'right_ear': [3,177],
291
+
292
+ # 'left_ear': [3,180],
293
+ # 'left_big_toe': [3,183],
294
+ # 'left_small_toe': [3,186],
295
+ # 'left_heel': [3,189],
296
+
297
+ # 'right_big_toe': [3,192],
298
+ # 'right_small_toe': [3,195],
299
+ # 'right_heel': [3,198],
300
+ # 'left_thumb': [3,201],
301
+ # 'left_index': [3,204],
302
+ # 'left_middle': [3,207],
303
+
304
+ # 'left_ring': [3,210],
305
+ # 'left_pinky': [3,213],
306
+ # 'right_thumb': [3,216],
307
+ # 'right_index': [3,219],
308
+ # 'right_middle': [3,222],
309
+ # 'right_ring': [3,225],
310
+
311
+ # 'right_pinky': [3,228],
312
+ # 'right_eye_brow1': [3,231],
313
+ # 'right_eye_brow2': [3,234],
314
+ # 'right_eye_brow3': [3,237],
315
+
316
+ # 'right_eye_brow4': [3,240],
317
+ # 'right_eye_brow5': [3,243],
318
+ # 'left_eye_brow5': [3,246],
319
+ # 'left_eye_brow4': [3,249],
320
+
321
+ # 'left_eye_brow3': [3,252],
322
+ # 'left_eye_brow2': [3,255],
323
+ # 'left_eye_brow1': [3,258],
324
+ # 'nose1': [3,261],
325
+ # 'nose2': [3,264],
326
+ # 'nose3': [3,267],
327
+
328
+ # 'nose4': [3,270],
329
+ # 'right_nose_2': [3,273],
330
+ # 'right_nose_1': [3,276],
331
+ # 'nose_middle': [3,279],
332
+ # 'left_nose_1': [3,282],
333
+ # 'left_nose_2': [3,285],
334
+
335
+ # 'right_eye1': [3,288],
336
+ # 'right_eye2': [3,291],
337
+ # 'right_eye3': [3,294],
338
+ # 'right_eye4': [3,297],
339
+
340
+ # 'right_eye5': [3,300],
341
+ # 'right_eye6': [3,303],
342
+ # 'left_eye4': [3,306],
343
+ # 'left_eye3': [3,309],
344
+
345
+ # 'left_eye2': [3,312],
346
+ # 'left_eye1': [3,315],
347
+ # 'left_eye6': [3,318],
348
+ # 'left_eye5': [3,321],
349
+ # 'right_mouth_1': [3,324],
350
+ # 'right_mouth_2': [3,327],
351
+ # 'right_mouth_3': [3,330],
352
+ # 'mouth_top': [3,333],
353
+ # 'left_mouth_3': [3,336],
354
+ # 'left_mouth_2': [3,339],
355
+ # 'left_mouth_1': [3,342],
356
+ # 'left_mouth_5': [3,345],
357
+ # 'left_mouth_4': [3,348],
358
+ # 'mouth_bottom': [3,351],
359
+ # 'right_mouth_4': [3,354],
360
+ # 'right_mouth_5': [3,357],
361
+ # 'right_lip_1': [3,360],
362
+ # 'right_lip_2': [3,363],
363
+ # 'lip_top': [3,366],
364
+ # 'left_lip_2': [3,369],
365
+
366
+ # 'left_lip_1': [3,372],
367
+ # 'left_lip_3': [3,375],
368
+ # 'lip_bottom': [3,378],
369
+ # 'right_lip_3': [3,381],
370
+ # 'right_contour_1': [3,384],
371
+ # 'right_contour_2': [3,387],
372
+ # 'right_contour_3': [3,390],
373
+ # 'right_contour_4': [3,393],
374
+ # 'right_contour_5': [3,396],
375
+ # 'right_contour_6': [3,399],
376
+ # 'right_contour_7': [3,402],
377
+ # 'right_contour_8': [3,405],
378
+ # 'contour_middle': [3,408],
379
+ # 'left_contour_8': [3,411],
380
+ # 'left_contour_7': [3,414],
381
+ # 'left_contour_6': [3,417],
382
+ # 'left_contour_5': [3,420],
383
+ # 'left_contour_4': [3,423],
384
+ # 'left_contour_3': [3,426],
385
+ # 'left_contour_2': [3,429],
386
+ # 'left_contour_1': [3,432],
387
+ },
388
+
389
+ "beat_smplx_no_eyes": {
390
+ "pelvis":3,
391
+ "left_hip":3,
392
+ "right_hip":3,
393
+ "spine1":3,
394
+ "left_knee":3,
395
+ "right_knee":3,
396
+ "spine2":3,
397
+ "left_ankle":3,
398
+ "right_ankle":3,
399
+ "spine3":3,
400
+ "left_foot":3,
401
+ "right_foot":3,
402
+ "neck":3,
403
+ "left_collar":3,
404
+ "right_collar":3,
405
+ "head":3,
406
+ "left_shoulder":3,
407
+ "right_shoulder":3,
408
+ "left_elbow":3,
409
+ "right_elbow":3,
410
+ "left_wrist":3,
411
+ "right_wrist":3,
412
+ "jaw":3,
413
+ # "left_eye_smplhf":3,
414
+ # "right_eye_smplhf":3,
415
+ "left_index1":3,
416
+ "left_index2":3,
417
+ "left_index3":3,
418
+ "left_middle1":3,
419
+ "left_middle2":3,
420
+ "left_middle3":3,
421
+ "left_pinky1":3,
422
+ "left_pinky2":3,
423
+ "left_pinky3":3,
424
+ "left_ring1":3,
425
+ "left_ring2":3,
426
+ "left_ring3":3,
427
+ "left_thumb1":3,
428
+ "left_thumb2":3,
429
+ "left_thumb3":3,
430
+ "right_index1":3,
431
+ "right_index2":3,
432
+ "right_index3":3,
433
+ "right_middle1":3,
434
+ "right_middle2":3,
435
+ "right_middle3":3,
436
+ "right_pinky1":3,
437
+ "right_pinky2":3,
438
+ "right_pinky3":3,
439
+ "right_ring1":3,
440
+ "right_ring2":3,
441
+ "right_ring3":3,
442
+ "right_thumb1":3,
443
+ "right_thumb2":3,
444
+ "right_thumb3":3,
445
+ },
446
+
447
+ "beat_smplx_full": {
448
+ "pelvis":3,
449
+ "left_hip":3,
450
+ "right_hip":3,
451
+ "spine1":3,
452
+ "left_knee":3,
453
+ "right_knee":3,
454
+ "spine2":3,
455
+ "left_ankle":3,
456
+ "right_ankle":3,
457
+ "spine3":3,
458
+ "left_foot":3,
459
+ "right_foot":3,
460
+ "neck":3,
461
+ "left_collar":3,
462
+ "right_collar":3,
463
+ "head":3,
464
+ "left_shoulder":3,
465
+ "right_shoulder":3,
466
+ "left_elbow":3,
467
+ "right_elbow":3,
468
+ "left_wrist":3,
469
+ "right_wrist":3,
470
+ "jaw":3,
471
+ "left_eye_smplhf":3,
472
+ "right_eye_smplhf":3,
473
+ "left_index1":3,
474
+ "left_index2":3,
475
+ "left_index3":3,
476
+ "left_middle1":3,
477
+ "left_middle2":3,
478
+ "left_middle3":3,
479
+ "left_pinky1":3,
480
+ "left_pinky2":3,
481
+ "left_pinky3":3,
482
+ "left_ring1":3,
483
+ "left_ring2":3,
484
+ "left_ring3":3,
485
+ "left_thumb1":3,
486
+ "left_thumb2":3,
487
+ "left_thumb3":3,
488
+ "right_index1":3,
489
+ "right_index2":3,
490
+ "right_index3":3,
491
+ "right_middle1":3,
492
+ "right_middle2":3,
493
+ "right_middle3":3,
494
+ "right_pinky1":3,
495
+ "right_pinky2":3,
496
+ "right_pinky3":3,
497
+ "right_ring1":3,
498
+ "right_ring2":3,
499
+ "right_ring3":3,
500
+ "right_thumb1":3,
501
+ "right_thumb2":3,
502
+ "right_thumb3":3,
503
+ },
504
+
505
+ "beat_smplx_upall": {
506
+ # "pelvis":3,
507
+ # "left_hip":3,
508
+ # "right_hip":3,
509
+ "spine1":3,
510
+ # "left_knee":3,
511
+ # "right_knee":3,
512
+ "spine2":3,
513
+ # "left_ankle":3,
514
+ # "right_ankle":3,
515
+ "spine3":3,
516
+ # "left_foot":3,
517
+ # "right_foot":3,
518
+ "neck":3,
519
+ "left_collar":3,
520
+ "right_collar":3,
521
+ "head":3,
522
+ "left_shoulder":3,
523
+ "right_shoulder":3,
524
+ "left_elbow":3,
525
+ "right_elbow":3,
526
+ "left_wrist":3,
527
+ "right_wrist":3,
528
+ # "jaw":3,
529
+ # "left_eye_smplhf":3,
530
+ # "right_eye_smplhf":3,
531
+ "left_index1":3,
532
+ "left_index2":3,
533
+ "left_index3":3,
534
+ "left_middle1":3,
535
+ "left_middle2":3,
536
+ "left_middle3":3,
537
+ "left_pinky1":3,
538
+ "left_pinky2":3,
539
+ "left_pinky3":3,
540
+ "left_ring1":3,
541
+ "left_ring2":3,
542
+ "left_ring3":3,
543
+ "left_thumb1":3,
544
+ "left_thumb2":3,
545
+ "left_thumb3":3,
546
+ "right_index1":3,
547
+ "right_index2":3,
548
+ "right_index3":3,
549
+ "right_middle1":3,
550
+ "right_middle2":3,
551
+ "right_middle3":3,
552
+ "right_pinky1":3,
553
+ "right_pinky2":3,
554
+ "right_pinky3":3,
555
+ "right_ring1":3,
556
+ "right_ring2":3,
557
+ "right_ring3":3,
558
+ "right_thumb1":3,
559
+ "right_thumb2":3,
560
+ "right_thumb3":3,
561
+ },
562
+
563
+ "beat_smplx_upper": {
564
+ #"pelvis":3,
565
+ # "left_hip":3,
566
+ # "right_hip":3,
567
+ "spine1":3,
568
+ # "left_knee":3,
569
+ # "right_knee":3,
570
+ "spine2":3,
571
+ # "left_ankle":3,
572
+ # "right_ankle":3,
573
+ "spine3":3,
574
+ # "left_foot":3,
575
+ # "right_foot":3,
576
+ "neck":3,
577
+ "left_collar":3,
578
+ "right_collar":3,
579
+ "head":3,
580
+ "left_shoulder":3,
581
+ "right_shoulder":3,
582
+ "left_elbow":3,
583
+ "right_elbow":3,
584
+ "left_wrist":3,
585
+ "right_wrist":3,
586
+ # "jaw":3,
587
+ # "left_eye_smplhf":3,
588
+ # "right_eye_smplhf":3,
589
+ # "left_index1":3,
590
+ # "left_index2":3,
591
+ # "left_index3":3,
592
+ # "left_middle1":3,
593
+ # "left_middle2":3,
594
+ # "left_middle3":3,
595
+ # "left_pinky1":3,
596
+ # "left_pinky2":3,
597
+ # "left_pinky3":3,
598
+ # "left_ring1":3,
599
+ # "left_ring2":3,
600
+ # "left_ring3":3,
601
+ # "left_thumb1":3,
602
+ # "left_thumb2":3,
603
+ # "left_thumb3":3,
604
+ # "right_index1":3,
605
+ # "right_index2":3,
606
+ # "right_index3":3,
607
+ # "right_middle1":3,
608
+ # "right_middle2":3,
609
+ # "right_middle3":3,
610
+ # "right_pinky1":3,
611
+ # "right_pinky2":3,
612
+ # "right_pinky3":3,
613
+ # "right_ring1":3,
614
+ # "right_ring2":3,
615
+ # "right_ring3":3,
616
+ # "right_thumb1":3,
617
+ # "right_thumb2":3,
618
+ # "right_thumb3":3,
619
+ },
620
+
621
+ "beat_smplx_hands": {
622
+ #"pelvis":3,
623
+ # "left_hip":3,
624
+ # "right_hip":3,
625
+ # "spine1":3,
626
+ # "left_knee":3,
627
+ # "right_knee":3,
628
+ # "spine2":3,
629
+ # "left_ankle":3,
630
+ # "right_ankle":3,
631
+ # "spine3":3,
632
+ # "left_foot":3,
633
+ # "right_foot":3,
634
+ # "neck":3,
635
+ # "left_collar":3,
636
+ # "right_collar":3,
637
+ # "head":3,
638
+ # "left_shoulder":3,
639
+ # "right_shoulder":3,
640
+ # "left_elbow":3,
641
+ # "right_elbow":3,
642
+ # "left_wrist":3,
643
+ # "right_wrist":3,
644
+ # "jaw":3,
645
+ # "left_eye_smplhf":3,
646
+ # "right_eye_smplhf":3,
647
+ "left_index1":3,
648
+ "left_index2":3,
649
+ "left_index3":3,
650
+ "left_middle1":3,
651
+ "left_middle2":3,
652
+ "left_middle3":3,
653
+ "left_pinky1":3,
654
+ "left_pinky2":3,
655
+ "left_pinky3":3,
656
+ "left_ring1":3,
657
+ "left_ring2":3,
658
+ "left_ring3":3,
659
+ "left_thumb1":3,
660
+ "left_thumb2":3,
661
+ "left_thumb3":3,
662
+ "right_index1":3,
663
+ "right_index2":3,
664
+ "right_index3":3,
665
+ "right_middle1":3,
666
+ "right_middle2":3,
667
+ "right_middle3":3,
668
+ "right_pinky1":3,
669
+ "right_pinky2":3,
670
+ "right_pinky3":3,
671
+ "right_ring1":3,
672
+ "right_ring2":3,
673
+ "right_ring3":3,
674
+ "right_thumb1":3,
675
+ "right_thumb2":3,
676
+ "right_thumb3":3,
677
+ },
678
+
679
+ "beat_smplx_lower": {
680
+ "pelvis":3,
681
+ "left_hip":3,
682
+ "right_hip":3,
683
+ # "spine1":3,
684
+ "left_knee":3,
685
+ "right_knee":3,
686
+ # "spine2":3,
687
+ "left_ankle":3,
688
+ "right_ankle":3,
689
+ # "spine3":3,
690
+ "left_foot":3,
691
+ "right_foot":3,
692
+ # "neck":3,
693
+ # "left_collar":3,
694
+ # "right_collar":3,
695
+ # "head":3,
696
+ # "left_shoulder":3,
697
+ # "right_shoulder":3,
698
+ # "left_elbow":3,
699
+ # "right_elbow":3,
700
+ # "left_wrist":3,
701
+ # "right_wrist":3,
702
+ # "jaw":3,
703
+ # "left_eye_smplhf":3,
704
+ # "right_eye_smplhf":3,
705
+ # "left_index1":3,
706
+ # "left_index2":3,
707
+ # "left_index3":3,
708
+ # "left_middle1":3,
709
+ # "left_middle2":3,
710
+ # "left_middle3":3,
711
+ # "left_pinky1":3,
712
+ # "left_pinky2":3,
713
+ # "left_pinky3":3,
714
+ # "left_ring1":3,
715
+ # "left_ring2":3,
716
+ # "left_ring3":3,
717
+ # "left_thumb1":3,
718
+ # "left_thumb2":3,
719
+ # "left_thumb3":3,
720
+ # "right_index1":3,
721
+ # "right_index2":3,
722
+ # "right_index3":3,
723
+ # "right_middle1":3,
724
+ # "right_middle2":3,
725
+ # "right_middle3":3,
726
+ # "right_pinky1":3,
727
+ # "right_pinky2":3,
728
+ # "right_pinky3":3,
729
+ # "right_ring1":3,
730
+ # "right_ring2":3,
731
+ # "right_ring3":3,
732
+ # "right_thumb1":3,
733
+ # "right_thumb2":3,
734
+ # "right_thumb3":3,
735
+ },
736
+
737
+ "beat_smplx_face": {
738
+ # "pelvis":3,
739
+ # "left_hip":3,
740
+ # "right_hip":3,
741
+ # # "spine1":3,
742
+ # "left_knee":3,
743
+ # "right_knee":3,
744
+ # # "spine2":3,
745
+ # "left_ankle":3,
746
+ # "right_ankle":3,
747
+ # # "spine3":3,
748
+ # "left_foot":3,
749
+ # "right_foot":3,
750
+ # "neck":3,
751
+ # "left_collar":3,
752
+ # "right_collar":3,
753
+ # "head":3,
754
+ # "left_shoulder":3,
755
+ # "right_shoulder":3,
756
+ # "left_elbow":3,
757
+ # "right_elbow":3,
758
+ # "left_wrist":3,
759
+ # "right_wrist":3,
760
+ "jaw":3,
761
+ # "left_eye_smplhf":3,
762
+ # "right_eye_smplhf":3,
763
+ # "left_index1":3,
764
+ # "left_index2":3,
765
+ # "left_index3":3,
766
+ # "left_middle1":3,
767
+ # "left_middle2":3,
768
+ # "left_middle3":3,
769
+ # "left_pinky1":3,
770
+ # "left_pinky2":3,
771
+ # "left_pinky3":3,
772
+ # "left_ring1":3,
773
+ # "left_ring2":3,
774
+ # "left_ring3":3,
775
+ # "left_thumb1":3,
776
+ # "left_thumb2":3,
777
+ # "left_thumb3":3,
778
+ # "right_index1":3,
779
+ # "right_index2":3,
780
+ # "right_index3":3,
781
+ # "right_middle1":3,
782
+ # "right_middle2":3,
783
+ # "right_middle3":3,
784
+ # "right_pinky1":3,
785
+ # "right_pinky2":3,
786
+ # "right_pinky3":3,
787
+ # "right_ring1":3,
788
+ # "right_ring2":3,
789
+ # "right_ring3":3,
790
+ # "right_thumb1":3,
791
+ # "right_thumb2":3,
792
+ # "right_thumb3":3,
793
+ },
794
+
795
+ "beat_joints": {
796
+ 'Hips': [6,6],
797
+ 'Spine': [3,9],
798
+ 'Spine1': [3,12],
799
+ 'Spine2': [3,15],
800
+ 'Spine3': [3,18],
801
+ 'Neck': [3,21],
802
+ 'Neck1': [3,24],
803
+ 'Head': [3,27],
804
+ 'HeadEnd': [3,30],
805
+
806
+ 'RShoulder': [3,33],
807
+ 'RArm': [3,36],
808
+ 'RArm1': [3,39],
809
+ 'RHand': [3,42],
810
+ 'RHandM1': [3,45],
811
+ 'RHandM2': [3,48],
812
+ 'RHandM3': [3,51],
813
+ 'RHandM4': [3,54],
814
+
815
+ 'RHandR': [3,57],
816
+ 'RHandR1': [3,60],
817
+ 'RHandR2': [3,63],
818
+ 'RHandR3': [3,66],
819
+ 'RHandR4': [3,69],
820
+
821
+ 'RHandP': [3,72],
822
+ 'RHandP1': [3,75],
823
+ 'RHandP2': [3,78],
824
+ 'RHandP3': [3,81],
825
+ 'RHandP4': [3,84],
826
+
827
+ 'RHandI': [3,87],
828
+ 'RHandI1': [3,90],
829
+ 'RHandI2': [3,93],
830
+ 'RHandI3': [3,96],
831
+ 'RHandI4': [3,99],
832
+
833
+ 'RHandT1': [3,102],
834
+ 'RHandT2': [3,105],
835
+ 'RHandT3': [3,108],
836
+ 'RHandT4': [3,111],
837
+
838
+ 'LShoulder': [3,114],
839
+ 'LArm': [3,117],
840
+ 'LArm1': [3,120],
841
+ 'LHand': [3,123],
842
+ 'LHandM1': [3,126],
843
+ 'LHandM2': [3,129],
844
+ 'LHandM3': [3,132],
845
+ 'LHandM4': [3,135],
846
+
847
+ 'LHandR': [3,138],
848
+ 'LHandR1': [3,141],
849
+ 'LHandR2': [3,144],
850
+ 'LHandR3': [3,147],
851
+ 'LHandR4': [3,150],
852
+
853
+ 'LHandP': [3,153],
854
+ 'LHandP1': [3,156],
855
+ 'LHandP2': [3,159],
856
+ 'LHandP3': [3,162],
857
+ 'LHandP4': [3,165],
858
+
859
+ 'LHandI': [3,168],
860
+ 'LHandI1': [3,171],
861
+ 'LHandI2': [3,174],
862
+ 'LHandI3': [3,177],
863
+ 'LHandI4': [3,180],
864
+
865
+ 'LHandT1': [3,183],
866
+ 'LHandT2': [3,186],
867
+ 'LHandT3': [3,189],
868
+ 'LHandT4': [3,192],
869
+
870
+ 'RUpLeg': [3,195],
871
+ 'RLeg': [3,198],
872
+ 'RFoot': [3,201],
873
+ 'RFootF': [3,204],
874
+ 'RToeBase': [3,207],
875
+ 'RToeBaseEnd': [3,210],
876
+
877
+ 'LUpLeg': [3,213],
878
+ 'LLeg': [3,216],
879
+ 'LFoot': [3,219],
880
+ 'LFootF': [3,222],
881
+ 'LToeBase': [3,225],
882
+ 'LToeBaseEnd': [3,228],},
883
+
884
+ "beat_full":{
885
+ 'Hips': 3,
886
+ 'Spine': 3 ,
887
+ 'Spine1': 3 ,
888
+ 'Spine2': 3 ,
889
+ 'Spine3': 3 ,
890
+ 'Neck': 3 ,
891
+ 'Neck1': 3 ,
892
+ 'Head' : 3,
893
+ 'HeadEnd' : 3,
894
+ 'RShoulder': 3 ,
895
+ 'RArm': 3 ,
896
+ 'RArm1': 3 ,
897
+ 'RHand': 3 ,
898
+ 'RHandM1': 3 ,
899
+ 'RHandM2': 3 ,
900
+ 'RHandM3': 3 ,
901
+ 'RHandM4': 3 ,
902
+ 'RHandR': 3 ,
903
+ 'RHandR1': 3 ,
904
+ 'RHandR2': 3 ,
905
+ 'RHandR3': 3 ,
906
+ 'RHandR4': 3 ,
907
+ 'RHandP': 3 ,
908
+ 'RHandP1': 3 ,
909
+ 'RHandP2': 3 ,
910
+ 'RHandP3': 3 ,
911
+ 'RHandP4': 3 ,
912
+ 'RHandI': 3 ,
913
+ 'RHandI1': 3 ,
914
+ 'RHandI2': 3 ,
915
+ 'RHandI3': 3 ,
916
+ 'RHandI4': 3 ,
917
+ 'RHandT1': 3 ,
918
+ 'RHandT2': 3 ,
919
+ 'RHandT3': 3 ,
920
+ 'RHandT4': 3 ,
921
+ 'LShoulder': 3 ,
922
+ 'LArm': 3 ,
923
+ 'LArm1': 3 ,
924
+ 'LHand': 3 ,
925
+ 'LHandM1': 3 ,
926
+ 'LHandM2': 3 ,
927
+ 'LHandM3': 3 ,
928
+ 'LHandM4': 3 ,
929
+ 'LHandR': 3 ,
930
+ 'LHandR1': 3 ,
931
+ 'LHandR2': 3 ,
932
+ 'LHandR3': 3 ,
933
+ 'LHandR4': 3 ,
934
+ 'LHandP': 3 ,
935
+ 'LHandP1': 3 ,
936
+ 'LHandP2': 3 ,
937
+ 'LHandP3': 3 ,
938
+ 'LHandP4': 3 ,
939
+ 'LHandI': 3 ,
940
+ 'LHandI1': 3 ,
941
+ 'LHandI2': 3 ,
942
+ 'LHandI3': 3 ,
943
+ 'LHandI4': 3 ,
944
+ 'LHandT1': 3 ,
945
+ 'LHandT2': 3 ,
946
+ 'LHandT3': 3 ,
947
+ 'LHandT4': 3 ,
948
+ 'RUpLeg': 3,
949
+ 'RLeg': 3,
950
+ 'RFoot': 3,
951
+ 'RFootF': 3,
952
+ 'RToeBase': 3,
953
+ 'RToeBaseEnd': 3,
954
+ 'LUpLeg': 3,
955
+ 'LLeg': 3,
956
+ 'LFoot': 3,
957
+ 'LFootF': 3,
958
+ 'LToeBase': 3,
959
+ 'LToeBaseEnd': 3,
960
+ },
961
+
962
+ "japanese_joints":{
963
+ 'Hips': [6,6],
964
+ 'Spine': [6,12],
965
+ 'Spine1': [6,18],
966
+ 'Spine2': [6,24],
967
+ 'Spine3': [6,30],
968
+ 'Neck': [6,36],
969
+ 'Neck1': [6,42],
970
+ 'Head': [6,48],
971
+ 'RShoulder': [6,54],
972
+ 'RArm': [6,60],
973
+ 'RArm1': [6,66],
974
+ 'RHand': [6,72],
975
+ 'RHandM1': [6,78],
976
+ 'RHandM2': [6,84],
977
+ 'RHandM3': [6,90],
978
+ 'RHandR': [6,96],
979
+ 'RHandR1': [6,102],
980
+ 'RHandR2': [6,108],
981
+ 'RHandR3': [6,114],
982
+ 'RHandP': [6,120],
983
+ 'RHandP1': [6,126],
984
+ 'RHandP2': [6,132],
985
+ 'RHandP3': [6,138],
986
+ 'RHandI': [6,144],
987
+ 'RHandI1': [6,150],
988
+ 'RHandI2': [6,156],
989
+ 'RHandI3': [6,162],
990
+ 'RHandT1': [6,168],
991
+ 'RHandT2': [6,174],
992
+ 'RHandT3': [6,180],
993
+ 'LShoulder': [6,186],
994
+ 'LArm': [6,192],
995
+ 'LArm1': [6,198],
996
+ 'LHand': [6,204],
997
+ 'LHandM1': [6,210],
998
+ 'LHandM2': [6,216],
999
+ 'LHandM3': [6,222],
1000
+ 'LHandR': [6,228],
1001
+ 'LHandR1': [6,234],
1002
+ 'LHandR2': [6,240],
1003
+ 'LHandR3': [6,246],
1004
+ 'LHandP': [6,252],
1005
+ 'LHandP1': [6,258],
1006
+ 'LHandP2': [6,264],
1007
+ 'LHandP3': [6,270],
1008
+ 'LHandI': [6,276],
1009
+ 'LHandI1': [6,282],
1010
+ 'LHandI2': [6,288],
1011
+ 'LHandI3': [6,294],
1012
+ 'LHandT1': [6,300],
1013
+ 'LHandT2': [6,306],
1014
+ 'LHandT3': [6,312],
1015
+ 'RUpLeg': [6,318],
1016
+ 'RLeg': [6,324],
1017
+ 'RFoot': [6,330],
1018
+ 'RFootF': [6,336],
1019
+ 'RToeBase': [6,342],
1020
+ 'LUpLeg': [6,348],
1021
+ 'LLeg': [6,354],
1022
+ 'LFoot': [6,360],
1023
+ 'LFootF': [6,366],
1024
+ 'LToeBase': [6,372],},
1025
+
1026
+ "yostar":{
1027
+ 'Hips': [6,6],
1028
+ 'Spine': [3,9],
1029
+ 'Spine1': [3,12],
1030
+ 'Bone040': [3,15],
1031
+ 'Bone041': [3,18],
1032
+
1033
+ 'Bone034': [3,21],
1034
+ 'Bone035': [3,24],
1035
+ 'Bone036': [3,27],
1036
+ 'Bone037': [3,30],
1037
+ 'Bone038': [3,33],
1038
+ 'Bone039': [3,36],
1039
+
1040
+ 'RibbonL1': [3,39],
1041
+ 'RibbonL1_end': [3,42],
1042
+
1043
+ 'Chest': [3,45],
1044
+ 'L_eri': [3,48],
1045
+ 'R_eri': [3,51],
1046
+ 'Neck': [3,54],
1047
+ 'Head': [3,57],
1048
+ 'Head_end': [3,60],
1049
+
1050
+ 'RBackHair_1': [3,63],
1051
+ 'RBackHair_2': [3,66],
1052
+ 'RBackHair_3': [3,69],
1053
+ 'RBackHair_4': [3,72],
1054
+ 'RBackHair_end': [3,75],
1055
+
1056
+ 'RFrontHair': [3,78],
1057
+ 'CFrontHair_1': [3,81],
1058
+ 'CFrontHair_2': [3,84],
1059
+ 'CFrontHair_3': [3,87],
1060
+ 'CFrontHair_emd': [3,90],
1061
+
1062
+ 'LFrontHair_1': [3,93],
1063
+ 'LFrontHair_2': [3,96],
1064
+ 'LFrontHair_3': [3,99],
1065
+
1066
+ 'LBackHair_1': [3,102],
1067
+ 'LBackHair_2': [3,105],
1068
+ 'LBackHair_3': [3,108],
1069
+ 'LBackHair_4': [3,111],
1070
+ 'LBackHair_end': [3,114],
1071
+
1072
+ 'LSideHair_1': [3,117],
1073
+ 'LSideHair_2': [3,120],
1074
+ 'LSideHair_3': [3,123],
1075
+ 'LSideHair_4': [3,126],
1076
+ 'LSideHair_5': [3,129],
1077
+ 'LSideHair_6': [3,132],
1078
+ 'LSideHair_7': [3,135],
1079
+ 'LSideHair_end': [3,138],
1080
+
1081
+ 'CBackHair_1': [3,141],
1082
+ 'CBackHair_2': [3,144],
1083
+ 'CBackHair_3': [3,147],
1084
+ 'CBackHair_4': [3,150],
1085
+ 'CBackHair_end': [3,153],
1086
+
1087
+ 'RSideHair_1': [3,156],
1088
+ 'RSideHair_2': [3,159],
1089
+ 'RSideHair_3': [3,162],
1090
+ 'RSideHair_4': [3,165],
1091
+
1092
+ 'RibbonR_1': [3,168],
1093
+ 'RibbonR_2': [3,171],
1094
+ 'RibbonR_3': [3,174],
1095
+
1096
+ 'RibbonL_1': [3,177],
1097
+ 'RibbonL_2': [3,180],
1098
+ 'RibbonL_3': [3,183],
1099
+
1100
+ 'LeftEye': [3,186],
1101
+ 'LeftEye_end': [3,189],
1102
+ 'RightEye': [3,192],
1103
+ 'RightEye_end': [3,195],
1104
+
1105
+ 'LeftShoulder': [3,198],
1106
+ 'LeftArm': [3,201],
1107
+ 'LeftForearm': [3,204],
1108
+ 'LeftHand': [3,207],
1109
+ 'LeftHandThumb1': [3,210],
1110
+ 'LeftHandThumb2': [3,213],
1111
+ 'LeftHandThumb3': [3,216],
1112
+ 'LeftHandThumb_end': [3,219],
1113
+
1114
+ 'LeftHandIndex1': [3,222],
1115
+ 'LeftHandIndex2': [3,225],
1116
+ 'LeftHandIndex3': [3,228],
1117
+ 'LeftHandIndex_end': [3,231],
1118
+
1119
+ 'LeftHandMiddle1': [3,234],
1120
+ 'LeftHandMiddle2': [3,237],
1121
+ 'LeftHandMiddle3': [3,240],
1122
+ 'LeftHandMiddle_end': [3,243],
1123
+
1124
+ 'LeftHandRing1': [3,246],
1125
+ 'LeftHandRing2': [3,249],
1126
+ 'LeftHandRing3': [3,252],
1127
+ 'LeftHandRing_end': [3,255],
1128
+
1129
+ 'LeftHandPinky1': [3,258],
1130
+ 'LeftHandPinky2': [3,261],
1131
+ 'LeftHandPinky3': [3,264],
1132
+ 'LeftHandPinky_end': [3,267],
1133
+
1134
+ 'RightShoulder': [3,270],
1135
+ 'RightArm': [3,273],
1136
+ 'RightForearm': [3,276],
1137
+ 'RightHand': [3,279],
1138
+ 'RightHandThumb1': [3,282],
1139
+ 'RightHandThumb2': [3,285],
1140
+ 'RightHandThumb3': [3,288],
1141
+ 'RightHandThumb_end': [3,291],
1142
+
1143
+ 'RightHandIndex1': [3,294],
1144
+ 'RightHandIndex2': [3,297],
1145
+ 'RightHandIndex3': [3,300],
1146
+ 'RightHandIndex_end': [3,303],
1147
+
1148
+ 'RightHandMiddle1': [3,306],
1149
+ 'RightHandMiddle2': [3,309],
1150
+ 'RightHandMiddle3': [3,312],
1151
+ 'RightHandMiddle_end': [3,315],
1152
+
1153
+ 'RightHandRing1': [3,318],
1154
+ 'RightHandRing2': [3,321],
1155
+ 'RightHandRing3': [3,324],
1156
+ 'RightHandRing_end': [3,327],
1157
+
1158
+ 'RightHandPinky1': [3,330],
1159
+ 'RightHandPinky2': [3,333],
1160
+ 'RightHandPinky3': [3,336],
1161
+ 'RightHandPinky_end': [3,339],
1162
+
1163
+ 'RibbonR1': [3,342],
1164
+ 'RibbonR1_end': [3,345],
1165
+ 'RibbonR2': [3,348],
1166
+ 'RibbonR2_end': [3,351],
1167
+ 'RibbonL2': [3,354],
1168
+ 'RibbonL2_end': [3,357],
1169
+
1170
+ 'LeftUpLeg': [3,360],
1171
+ 'LeftLeg': [3,363],
1172
+ 'LeftFoot': [3,366],
1173
+ 'LeftToe': [3,369],
1174
+ 'LeftToe_end': [3,372],
1175
+
1176
+ 'RightUpLeg': [3,375],
1177
+ 'RightLEg': [3,378],
1178
+ 'RightFoot': [3,381],
1179
+ 'RightToe': [3,384],
1180
+ 'RightToe_end': [3,387],
1181
+
1182
+ 'bone_skirtF00': [3, 390],
1183
+ 'bone_skirtF01': [3, 393],
1184
+ 'bone_skirtF02': [3, 396],
1185
+ 'bone_skirtF03': [3, 399],
1186
+ 'Bone020': [3, 402],
1187
+ 'Bone026': [3, 405],
1188
+
1189
+ 'bone_skirtF_R_00': [3, 408],
1190
+ 'bone_skirtF_R_01': [3, 411],
1191
+ 'bone_skirtF_R_02': [3, 414],
1192
+ 'bone_skirtF_R_03': [3, 417],
1193
+ 'Bone019': [3, 420],
1194
+ 'Bone028': [3, 423],
1195
+
1196
+ 'bone_skirtR00': [3, 426],
1197
+ 'bone_skirtR01': [3, 429],
1198
+ 'bone_skirtR02': [3, 432],
1199
+ 'bone_skirtR03': [3, 435],
1200
+ 'Bone018': [3, 438],
1201
+ 'Bone029': [3, 441],
1202
+
1203
+ 'bone_skirtF_L_00': [3, 444],
1204
+ 'bone_skirtF_L_01': [3, 447],
1205
+ 'bone_skirtF_L_02': [3, 450],
1206
+ 'bone_skirtF_L_03': [3, 453],
1207
+ 'Bone021': [3, 456],
1208
+ 'Bone027': [3, 459],
1209
+
1210
+ 'bone_skirtL00': [3, 462],
1211
+ 'bone_skirtL01': [3, 465],
1212
+ 'bone_skirtL02': [3, 468],
1213
+ 'bone_skirtL03': [3, 471],
1214
+ 'Bone022': [3, 474],
1215
+ 'Bone033': [3, 477],
1216
+
1217
+ 'bone_skirtB_L_00': [3, 480],
1218
+ 'bone_skirtB_L_01': [3, 483],
1219
+ 'bone_skirtB_L_02': [3, 486],
1220
+ 'bone_skirtB_L_03': [3, 489],
1221
+ 'Bone023': [3, 492],
1222
+ 'Bone032': [3, 495],
1223
+
1224
+ 'bone_skirtB00': [3, 498],
1225
+ 'bone_skirtB01': [3, 501],
1226
+ 'bone_skirtB02': [3, 504],
1227
+ 'bone_skirtB03': [3, 507],
1228
+ 'Bone024': [3, 510],
1229
+ 'Bone031': [3, 513],
1230
+
1231
+ 'bone_skirtB_R_00': [3, 516],
1232
+ 'bone_skirtB_R_01': [3, 519],
1233
+ 'bone_skirtB_R_02': [3, 521],
1234
+ 'bone_skirtB_R_03': [3, 524],
1235
+ 'Bone025': [3, 527],
1236
+ 'Bone030': [3, 530],
1237
+ },
1238
+
1239
+ "yostar_fullbody_213":{
1240
+ 'Hips': 3 ,
1241
+ 'Spine': 3 ,
1242
+ 'Spine1': 3 ,
1243
+ 'Chest': 3 ,
1244
+ 'L_eri': 3 ,
1245
+ 'R_eri': 3 ,
1246
+ 'Neck': 3 ,
1247
+ 'Head': 3 ,
1248
+ 'Head_end': 3 ,
1249
+
1250
+ 'LeftEye': 3,
1251
+ 'LeftEye_end': 3,
1252
+ 'RightEye': 3,
1253
+ 'RightEye_end': 3,
1254
+
1255
+ 'LeftShoulder': 3,
1256
+ 'LeftArm': 3,
1257
+ 'LeftForearm': 3,
1258
+ 'LeftHand': 3,
1259
+ 'LeftHandThumb1': 3,
1260
+ 'LeftHandThumb2': 3,
1261
+ 'LeftHandThumb3': 3,
1262
+ 'LeftHandThumb_end': 3,
1263
+
1264
+ 'LeftHandIndex1': 3,
1265
+ 'LeftHandIndex2': 3,
1266
+ 'LeftHandIndex3': 3,
1267
+ 'LeftHandIndex_end': 3,
1268
+
1269
+ 'LeftHandMiddle1': 3,
1270
+ 'LeftHandMiddle2': 3,
1271
+ 'LeftHandMiddle3': 3,
1272
+ 'LeftHandMiddle_end': 3,
1273
+
1274
+ 'LeftHandRing1': 3,
1275
+ 'LeftHandRing2': 3,
1276
+ 'LeftHandRing3': 3,
1277
+ 'LeftHandRing_end': 3,
1278
+
1279
+ 'LeftHandPinky1': 3,
1280
+ 'LeftHandPinky2': 3,
1281
+ 'LeftHandPinky3': 3,
1282
+ 'LeftHandPinky_end':3,
1283
+
1284
+ 'RightShoulder': 3,
1285
+ 'RightArm': 3,
1286
+ 'RightForearm': 3,
1287
+ 'RightHand': 3,
1288
+ 'RightHandThumb1': 3,
1289
+ 'RightHandThumb2': 3,
1290
+ 'RightHandThumb3': 3,
1291
+ 'RightHandThumb_end': 3,
1292
+
1293
+ 'RightHandIndex1': 3,
1294
+ 'RightHandIndex2': 3,
1295
+ 'RightHandIndex3': 3,
1296
+ 'RightHandIndex_end': 3,
1297
+
1298
+ 'RightHandMiddle1': 3,
1299
+ 'RightHandMiddle2': 3,
1300
+ 'RightHandMiddle3': 3,
1301
+ 'RightHandMiddle_end': 3,
1302
+
1303
+ 'RightHandRing1': 3,
1304
+ 'RightHandRing2': 3,
1305
+ 'RightHandRing3': 3,
1306
+ 'RightHandRing_end': 3,
1307
+
1308
+ 'RightHandPinky1': 3,
1309
+ 'RightHandPinky2': 3,
1310
+ 'RightHandPinky3': 3,
1311
+ 'RightHandPinky_end': 3,
1312
+
1313
+ 'LeftUpLeg': 3,
1314
+ 'LeftLeg': 3,
1315
+ 'LeftFoot': 3,
1316
+ 'LeftToe': 3,
1317
+ 'LeftToe_end': 3,
1318
+
1319
+ 'RightUpLeg': 3,
1320
+ 'RightLEg': 3,
1321
+ 'RightFoot': 3,
1322
+ 'RightToe': 3,
1323
+ 'RightToe_end': 3,
1324
+ },
1325
+ "yostar_mainbody_48": {
1326
+ #'Hips': 3 ,
1327
+ 'Spine': 3 ,
1328
+ 'Spine1': 3 ,
1329
+ 'Chest': 3 ,
1330
+ 'L_eri': 3 ,
1331
+ 'R_eri': 3 ,
1332
+ 'Neck': 3 ,
1333
+ 'Head': 3 ,
1334
+ 'Head_end': 3 ,
1335
+
1336
+ 'LeftShoulder': 3,
1337
+ 'LeftArm': 3,
1338
+ 'LeftForearm': 3,
1339
+ 'LeftHand': 3,
1340
+
1341
+ 'RightShoulder': 3,
1342
+ 'RightArm': 3,
1343
+ 'RightForearm': 3,
1344
+ 'RightHand': 3,
1345
+ },
1346
+ "yostar_mainbody_69": {
1347
+ 'Hips': 3 ,
1348
+ 'Spine': 3 ,
1349
+ 'Spine1': 3 ,
1350
+ 'Chest': 3 ,
1351
+ 'L_eri': 3 ,
1352
+ 'R_eri': 3 ,
1353
+ 'Neck': 3 ,
1354
+ 'Head': 3 ,
1355
+ 'Head_end': 3 ,
1356
+
1357
+ 'LeftShoulder': 3,
1358
+ 'LeftArm': 3,
1359
+ 'LeftForearm': 3,
1360
+ 'LeftHand': 3,
1361
+
1362
+ 'RightShoulder': 3,
1363
+ 'RightArm': 3,
1364
+ 'RightForearm': 3,
1365
+ 'RightHand': 3,
1366
+
1367
+ 'LeftUpLeg': 3,
1368
+ 'LeftLeg': 3,
1369
+ 'LeftFoot': 3,
1370
+
1371
+ 'RightUpLeg': 3,
1372
+ 'RightLEg': 3,
1373
+ 'RightFoot': 3,
1374
+ },
1375
+
1376
+ "yostar_upbody_168": {
1377
+ #'Hips': 3 ,
1378
+ 'Spine': 3 ,
1379
+ 'Spine1': 3 ,
1380
+ 'Chest': 3 ,
1381
+ 'L_eri': 3 ,
1382
+ 'R_eri': 3 ,
1383
+ 'Neck': 3 ,
1384
+ 'Head': 3 ,
1385
+ 'Head_end': 3 ,
1386
+
1387
+ 'LeftShoulder': 3,
1388
+ 'LeftArm': 3,
1389
+ 'LeftForearm': 3,
1390
+ 'LeftHand': 3,
1391
+ 'LeftHandThumb1': 3,
1392
+ 'LeftHandThumb2': 3,
1393
+ 'LeftHandThumb3': 3,
1394
+ 'LeftHandThumb_end': 3,
1395
+
1396
+ 'LeftHandIndex1': 3,
1397
+ 'LeftHandIndex2': 3,
1398
+ 'LeftHandIndex3': 3,
1399
+ 'LeftHandIndex_end': 3,
1400
+
1401
+ 'LeftHandMiddle1': 3,
1402
+ 'LeftHandMiddle2': 3,
1403
+ 'LeftHandMiddle3': 3,
1404
+ 'LeftHandMiddle_end': 3,
1405
+
1406
+ 'LeftHandRing1': 3,
1407
+ 'LeftHandRing2': 3,
1408
+ 'LeftHandRing3': 3,
1409
+ 'LeftHandRing_end': 3,
1410
+
1411
+ 'LeftHandPinky1': 3,
1412
+ 'LeftHandPinky2': 3,
1413
+ 'LeftHandPinky3': 3,
1414
+ 'LeftHandPinky_end':3,
1415
+
1416
+ 'RightShoulder': 3,
1417
+ 'RightArm': 3,
1418
+ 'RightForearm': 3,
1419
+ 'RightHand': 3,
1420
+ 'RightHandThumb1': 3,
1421
+ 'RightHandThumb2': 3,
1422
+ 'RightHandThumb3': 3,
1423
+ 'RightHandThumb_end': 3,
1424
+
1425
+ 'RightHandIndex1': 3,
1426
+ 'RightHandIndex2': 3,
1427
+ 'RightHandIndex3': 3,
1428
+ 'RightHandIndex_end': 3,
1429
+
1430
+ 'RightHandMiddle1': 3,
1431
+ 'RightHandMiddle2': 3,
1432
+ 'RightHandMiddle3': 3,
1433
+ 'RightHandMiddle_end': 3,
1434
+
1435
+ 'RightHandRing1': 3,
1436
+ 'RightHandRing2': 3,
1437
+ 'RightHandRing3': 3,
1438
+ 'RightHandRing_end': 3,
1439
+
1440
+ 'RightHandPinky1': 3,
1441
+ 'RightHandPinky2': 3,
1442
+ 'RightHandPinky3': 3,
1443
+ 'RightHandPinky_end': 3,
1444
+ },
1445
+ "spine_neck_141":{
1446
+ 'Spine': 3 ,
1447
+ 'Neck': 3 ,
1448
+ 'Neck1': 3 ,
1449
+ 'RShoulder': 3 ,
1450
+ 'RArm': 3 ,
1451
+ 'RArm1': 3 ,
1452
+ 'RHand': 3 ,
1453
+ 'RHandM1': 3 ,
1454
+ 'RHandM2': 3 ,
1455
+ 'RHandM3': 3 ,
1456
+ 'RHandR': 3 ,
1457
+ 'RHandR1': 3 ,
1458
+ 'RHandR2': 3 ,
1459
+ 'RHandR3': 3 ,
1460
+ 'RHandP': 3 ,
1461
+ 'RHandP1': 3 ,
1462
+ 'RHandP2': 3 ,
1463
+ 'RHandP3': 3 ,
1464
+ 'RHandI': 3 ,
1465
+ 'RHandI1': 3 ,
1466
+ 'RHandI2': 3 ,
1467
+ 'RHandI3': 3 ,
1468
+ 'RHandT1': 3 ,
1469
+ 'RHandT2': 3 ,
1470
+ 'RHandT3': 3 ,
1471
+ 'LShoulder': 3 ,
1472
+ 'LArm': 3 ,
1473
+ 'LArm1': 3 ,
1474
+ 'LHand': 3 ,
1475
+ 'LHandM1': 3 ,
1476
+ 'LHandM2': 3 ,
1477
+ 'LHandM3': 3 ,
1478
+ 'LHandR': 3 ,
1479
+ 'LHandR1': 3 ,
1480
+ 'LHandR2': 3 ,
1481
+ 'LHandR3': 3 ,
1482
+ 'LHandP': 3 ,
1483
+ 'LHandP1': 3 ,
1484
+ 'LHandP2': 3 ,
1485
+ 'LHandP3': 3 ,
1486
+ 'LHandI': 3 ,
1487
+ 'LHandI1': 3 ,
1488
+ 'LHandI2': 3 ,
1489
+ 'LHandI3': 3 ,
1490
+ 'LHandT1': 3 ,
1491
+ 'LHandT2': 3 ,
1492
+ 'LHandT3': 3 ,},
1493
+ }
1494
+
1495
+
1496
+ class FIDCalculator(object):
1497
+ '''
1498
+ todo
1499
+ '''
1500
+ def __init__(self):
1501
+ self.gt_rot = None # pandas dataframe for n frames * joints * 6
1502
+ self.gt_pos = None # n frames * (joints + 13) * 3
1503
+ self.op_rot = None # pandas dataframe for n frames * joints * 6
1504
+ self.op_pos = None # n frames * (joints + 13) * 3
1505
+
1506
+
1507
+ def load(self, path, load_type, save_pos=False):
1508
+ '''
1509
+ select gt or op for load_type
1510
+ '''
1511
+ parser = BVHParser()
1512
+ parsed_data = parser.parse(path)
1513
+ if load_type == 'gt':
1514
+ self.gt_rot = parsed_data.values
1515
+ elif load_type == 'op':
1516
+ self.op_rot = parsed_data.values
1517
+ else: print('error, select gt or op for load_type')
1518
+
1519
+ if save_pos:
1520
+ mp = MocapParameterizer('position')
1521
+ positions = mp.fit_transform([parsed_data])
1522
+ if load_type == 'gt':
1523
+ self.gt_pos = positions[0].values
1524
+ elif load_type == 'op':
1525
+ self.op_pos = positions[0].values
1526
+ else: print('error, select gt or op for load_type')
1527
+
1528
+
1529
+ def _joint_selector(self, selected_joints, ori_data):
1530
+ selected_data = pd.DataFrame(columns=[])
1531
+
1532
+ for joint_name in selected_joints:
1533
+ selected_data[joint_name] = ori_data[joint_name]
1534
+ return selected_data.to_numpy()
1535
+
1536
+
1537
+ def cal_vol(self, dtype):
1538
+ if dtype == 'pos':
1539
+ gt = self.gt_pos
1540
+ op = self.op_pos
1541
+ else:
1542
+ gt = self.gt_rot
1543
+ op = self.op_rot
1544
+
1545
+ gt_v = gt.to_numpy()[1:, :] - gt.to_numpy()[0:-1, :]
1546
+ op_v = op.to_numpy()[1:, :] - op.to_numpy()[0:-1, :]
1547
+ if dtype == 'pos':
1548
+ self.gt_vol_pos = pd.DataFrame(gt_v, columns = gt.columns.tolist())
1549
+ self.op_vol_pos = pd.DataFrame(op_v, columns = gt.columns.tolist())
1550
+ else:
1551
+ self.gt_vol_rot = pd.DataFrame(gt_v, columns = gt.columns.tolist())
1552
+ self.op_vol_rot = pd.DataFrame(op_v, columns = gt.columns.tolist())
1553
+
1554
+
1555
+ @staticmethod
1556
+ def frechet_distance(samples_A, samples_B):
1557
+ A_mu = np.mean(samples_A, axis=0)
1558
+ A_sigma = np.cov(samples_A, rowvar=False)
1559
+ B_mu = np.mean(samples_B, axis=0)
1560
+ B_sigma = np.cov(samples_B, rowvar=False)
1561
+ try:
1562
+ frechet_dist = FIDCalculator.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
1563
+ except ValueError:
1564
+ frechet_dist = 1e+10
1565
+ return frechet_dist
1566
+
1567
+
1568
+ @staticmethod
1569
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
1570
+ """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """
1571
+ """Numpy implementation of the Frechet Distance.
1572
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
1573
+ and X_2 ~ N(mu_2, C_2) is
1574
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
1575
+ Stable version by Dougal J. Sutherland.
1576
+ Params:
1577
+ -- mu1 : Numpy array containing the activations of a layer of the
1578
+ inception net (like returned by the function 'get_predictions')
1579
+ for generated samples.
1580
+ -- mu2 : The sample mean over activations, precalculated on an
1581
+ representative data set.
1582
+ -- sigma1: The covariance matrix over activations for generated samples.
1583
+ -- sigma2: The covariance matrix over activations, precalculated on an
1584
+ representative data set.
1585
+ Returns:
1586
+ -- : The Frechet Distance.
1587
+ """
1588
+
1589
+ mu1 = np.atleast_1d(mu1)
1590
+ mu2 = np.atleast_1d(mu2)
1591
+ #print(mu1[0], mu2[0])
1592
+ sigma1 = np.atleast_2d(sigma1)
1593
+ sigma2 = np.atleast_2d(sigma2)
1594
+ #print(sigma1[0], sigma2[0])
1595
+ assert mu1.shape == mu2.shape, \
1596
+ 'Training and test mean vectors have different lengths'
1597
+ assert sigma1.shape == sigma2.shape, \
1598
+ 'Training and test covariances have different dimensions'
1599
+
1600
+ diff = mu1 - mu2
1601
+
1602
+ # Product might be almost singular
1603
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
1604
+ #print(diff, covmean[0])
1605
+ if not np.isfinite(covmean).all():
1606
+ msg = ('fid calculation produces singular product; '
1607
+ 'adding %s to diagonal of cov estimates') % eps
1608
+ print(msg)
1609
+ offset = np.eye(sigma1.shape[0]) * eps
1610
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
1611
+
1612
+ # Numerical error might give slight imaginary component
1613
+ if np.iscomplexobj(covmean):
1614
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
1615
+ m = np.max(np.abs(covmean.imag))
1616
+ raise ValueError('Imaginary component {}'.format(m))
1617
+ covmean = covmean.real
1618
+
1619
+ tr_covmean = np.trace(covmean)
1620
+
1621
+ return (diff.dot(diff) + np.trace(sigma1) +
1622
+ np.trace(sigma2) - 2 * tr_covmean)
1623
+
1624
+
1625
+ def calculate_fid(self, cal_type, joint_type, high_level_opt):
1626
+
1627
+ if cal_type == 'pos':
1628
+ if self.gt_pos.shape != self.op_pos.shape:
1629
+ min_val = min(self.gt_pos.shape[0],self.op_pos.shape[0])
1630
+ gt = self.gt_pos[:min_val]
1631
+ op = self.op_pos[:min_val]
1632
+ else:
1633
+ gt = self.gt_pos
1634
+ op = self.op_pos
1635
+ full_body = gt.columns.tolist()
1636
+ elif cal_type == 'rot':
1637
+ if self.gt_rot.shape != self.op_rot.shape:
1638
+ min_val = min(self.gt_rot.shape[0],self.op_rot.shape[0])
1639
+ gt = self.gt_rot[:min_val]
1640
+ op = self.op_rot[:min_val]
1641
+ else:
1642
+ gt = self.gt_rot
1643
+ op = self.op_rot
1644
+ full_body_with_offset = gt.columns.tolist()
1645
+ full_body = [o for o in full_body_with_offset if ('position' not in o)]
1646
+ elif cal_type == 'pos_vol':
1647
+ assert self.gt_vol_pos.shape == self.op_vol_pos.shape
1648
+ gt = self.gt_vol_pos
1649
+ op = self.op_vol_pos
1650
+ full_body_with_offset = gt.columns.tolist()
1651
+ full_body = gt.columns.tolist()
1652
+ elif cal_type == 'rot_vol':
1653
+ assert self.gt_vol_rot.shape == self.op_vol_rot.shape
1654
+ gt = self.gt_vol_rot
1655
+ op = self.op_vol_rot
1656
+ full_body_with_offset = gt.columns.tolist()
1657
+ full_body = [o for o in full_body_with_offset if ('position' not in o)]
1658
+ #print(f'full_body contains {len(full_body)//3} joints')
1659
+
1660
+ if joint_type == 'full_upper_body':
1661
+ selected_body = [o for o in full_body if ('Leg' not in o) and ('Foot' not in o) and ('Toe' not in o)]
1662
+ elif joint_type == 'upper_body':
1663
+ selected_body = [o for o in full_body if ('Hand' not in o) and ('Leg' not in o) and ('Foot' not in o) and ('Toe' not in o)]
1664
+ elif joint_type == 'fingers':
1665
+ selected_body = [o for o in full_body if ('Hand' in o)]
1666
+ elif joint_type == 'indivdual':
1667
+ pass
1668
+ else: print('error, please select a correct joint type')
1669
+ #print(f'calculate fid for {len(selected_body)//3} joints')
1670
+
1671
+ gt = self._joint_selector(selected_body, gt)
1672
+ op = self._joint_selector(selected_body, op)
1673
+
1674
+ if high_level_opt == 'fid':
1675
+ fid = FIDCalculator.frechet_distance(gt, op)
1676
+ return fid
1677
+ elif high_level_opt == 'var':
1678
+ var_gt = gt.var()
1679
+ var_op = op.var()
1680
+ return var_gt, var_op
1681
+ elif high_level_opt == 'mean':
1682
+ mean_gt = gt.mean()
1683
+ mean_op = op.mean()
1684
+ return mean_gt, mean_op
1685
+ else: return 0
1686
+
1687
+
1688
+ def result2target_vis(pose_version, res_bvhlist, save_path, demo_name, verbose=True):
1689
+ if "trinity" in pose_version:
1690
+ ori_list = joints_list[pose_version[6:-4]]
1691
+ target_list = joints_list[pose_version[6:]]
1692
+ file_content_length = 336
1693
+ elif "beat" in pose_version or "spine_neck_141" in pose_version:
1694
+ ori_list = joints_list["beat_joints"]
1695
+ target_list = joints_list["spine_neck_141"]
1696
+ file_content_length = 431
1697
+ elif "yostar" in pose_version:
1698
+ ori_list = joints_list["yostar"]
1699
+ target_list = joints_list[pose_version]
1700
+ file_content_length = 1056
1701
+ else:
1702
+ ori_list = joints_list["japanese_joints"]
1703
+ target_list = joints_list[pose_version]
1704
+ file_content_length = 366
1705
+
1706
+ bvh_files_dirs = sorted(glob.glob(f'{res_bvhlist}*.bvh'), key=str)
1707
+ #test_seq_list = os.list_dir(demo_name).sort()
1708
+
1709
+ counter = 0
1710
+ if not os.path.exists(save_path):
1711
+ os.makedirs(save_path)
1712
+ for i, bvh_file_dir in enumerate(bvh_files_dirs):
1713
+ short_name = bvh_file_dir.split("/")[-1][11:]
1714
+ #print(short_name)
1715
+ wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'w+')
1716
+ with open(f"{demo_name}{short_name}",'r') as pose_data_pre:
1717
+ pose_data_pre_file = pose_data_pre.readlines()
1718
+ for j, line in enumerate(pose_data_pre_file[0:file_content_length]):
1719
+ wirte_file.write(line)
1720
+ offset_data = pose_data_pre_file[file_content_length]
1721
+ offset_data = np.fromstring(offset_data, dtype=float, sep=' ')
1722
+ wirte_file.close()
1723
+
1724
+ wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'r')
1725
+ ori_lines = wirte_file.readlines()
1726
+ with open(bvh_file_dir, 'r') as pose_data:
1727
+ pose_data_file = pose_data.readlines()
1728
+ ori_lines[file_content_length-2] = 'Frames: ' + str(len(pose_data_file)-1) + '\n'
1729
+ wirte_file.close()
1730
+
1731
+ wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'w+')
1732
+ wirte_file.writelines(i for i in ori_lines[:file_content_length])
1733
+ wirte_file.close()
1734
+
1735
+ with open(os.path.join(save_path, f'res_{short_name}'),'a+') as wirte_file:
1736
+ with open(bvh_file_dir, 'r') as pose_data:
1737
+ data_each_file = []
1738
+ pose_data_file = pose_data.readlines()
1739
+ for j, line in enumerate(pose_data_file):
1740
+ if not j:
1741
+ pass
1742
+ else:
1743
+ data = np.fromstring(line, dtype=float, sep=' ')
1744
+ data_rotation = offset_data.copy()
1745
+ for iii, (k, v) in enumerate(target_list.items()): # here is 147 rotations by 3
1746
+ #print(data_rotation[ori_list[k][1]-v:ori_list[k][1]], data[iii*3:iii*3+3])
1747
+ data_rotation[ori_list[k][1]-v:ori_list[k][1]] = data[iii*3:iii*3+3]
1748
+ data_each_file.append(data_rotation)
1749
+
1750
+ for line_data in data_each_file:
1751
+ line_data = np.array2string(line_data, max_line_width=np.inf, precision=6, suppress_small=False, separator=' ')
1752
+ wirte_file.write(line_data[1:-2]+'\n')
1753
+
1754
+ counter += 1
1755
+ if verbose:
1756
+ logger.info(f'data_shape: {data_rotation.shape}, process: {counter} / {len(bvh_files_dirs)}')
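
Since data_tools.py ends with the FID utilities, a small numeric check of the Frechet formula in the docstring, d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)), may help. This is a standalone re-implementation of the same equation rather than a call into FIDCalculator (so it avoids the module's BVH and fasttext imports); the sample sizes and the 0.5 mean shift are arbitrary illustration values:

```python
import numpy as np
from scipy import linalg

def frechet_distance(samples_a, samples_b):
    # same formula as FIDCalculator.calculate_frechet_distance, without the stability guards
    mu1, mu2 = samples_a.mean(axis=0), samples_b.mean(axis=0)
    c1 = np.cov(samples_a, rowvar=False)
    c2 = np.cov(samples_b, rowvar=False)
    covmean = linalg.sqrtm(c1 @ c2).real
    diff = mu1 - mu2
    return diff @ diff + np.trace(c1) + np.trace(c2) - 2.0 * np.trace(covmean)

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=(2000, 4))   # reference samples
b = rng.normal(0.5, 1.0, size=(2000, 4))   # same covariance, mean shifted by 0.5 per dim

print(frechet_distance(a, a))   # ~0: identical sample sets
print(frechet_distance(a, b))   # ~1: ||mu1 - mu2||^2 = 4 * 0.5^2, covariance terms nearly cancel
```
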
dataloaders/mix_sep.py ADDED
@@ -0,0 +1,301 @@
1
+ import os
2
+ import pickle
3
+ import math
4
+ import shutil
5
+ import numpy as np
6
+ import lmdb as lmdb
7
+ import textgrid as tg
8
+ import pandas as pd
9
+ import torch
10
+ import glob
11
+ import json
12
+ from termcolor import colored
13
+ from loguru import logger
14
+ from collections import defaultdict
15
+ from torch.utils.data import Dataset
16
+ import torch.distributed as dist
17
+ #import pyarrow
18
+ import pickle
19
+ import librosa
20
+ import smplx
21
+ import glob
22
+
23
+ from .build_vocab import Vocab
24
+ from .utils.audio_features import Wav2Vec2Model
25
+ from .data_tools import joints_list
26
+ from .utils import rotation_conversions as rc
27
+ from .utils import other_tools
28
+
29
+
30
+ class CustomDataset(Dataset):
31
+ def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
32
+ self.args = args
33
+ self.loader_type = loader_type
34
+
35
+ self.rank = 0
36
+ self.ori_stride = self.args.stride
37
+ self.ori_length = self.args.pose_length
38
+
39
+ self.ori_joint_list = joints_list[self.args.ori_joints]
40
+ self.tar_joint_list = joints_list[self.args.tar_joints]
41
+ if 'smplx' in self.args.pose_rep:
42
+ self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3)
43
+ self.joints = len(list(self.tar_joint_list.keys()))
44
+ for joint_name in self.tar_joint_list:
45
+ self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
46
+ else:
47
+ self.joints = len(list(self.ori_joint_list.keys()))+1
48
+ self.joint_mask = np.zeros(self.joints*3)
49
+ for joint_name in self.tar_joint_list:
50
+ if joint_name == "Hips":
51
+ self.joint_mask[3:6] = 1
52
+ else:
53
+ self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
54
+ # select trainable joints
55
+
56
+ split_rule = pd.read_csv(args.data_path+"train_test_split.csv")
57
+ self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
58
+ if args.additional_data and loader_type == 'train':
59
+ split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
60
+ #self.selected_file = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
61
+ self.selected_file = pd.concat([self.selected_file, split_b])
62
+ if self.selected_file.empty:
63
+ logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead")
64
+ self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
65
+ self.selected_file = self.selected_file.iloc[0:8]
66
+ self.data_dir = args.data_path
67
+ self.beatx_during_time = 0
68
+
69
+ if loader_type == "test":
70
+ self.args.multi_length_training = [1.0]
71
+ self.max_length = int(args.pose_length * self.args.multi_length_training[-1])
72
+ self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr)
73
+ if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr:
74
+ self.max_audio_pre_len = self.args.test_length*self.args.audio_sr
75
+ preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache"
76
+
77
+
78
+ if build_cache and self.rank == 0:
79
+ self.build_cache(preloaded_dir)
80
+ self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False)
81
+ with self.lmdb_env.begin() as txn:
82
+ self.n_samples = txn.stat()["entries"]
83
+
84
+ self.norm = True
85
+ self.mean = np.load('./mean_std/beatx_2_330_mean.npy')
86
+ self.std = np.load('./mean_std/beatx_2_330_std.npy')
87
+
88
+ self.trans_mean = np.load('./mean_std/beatx_2_trans_mean.npy')
89
+ self.trans_std = np.load('./mean_std/beatx_2_trans_std.npy')
90
+
91
+ def build_cache(self, preloaded_dir):
92
+ logger.info(f"Audio bit rate: {self.args.audio_fps}")
93
+ logger.info("Reading data '{}'...".format(self.data_dir))
94
+ logger.info("Creating the dataset cache...")
95
+ if self.args.new_cache:
96
+ if os.path.exists(preloaded_dir):
97
+ shutil.rmtree(preloaded_dir)
98
+ if os.path.exists(preloaded_dir):
99
+ logger.info("Found the cache {}".format(preloaded_dir))
100
+ elif self.loader_type == "test":
101
+ self.cache_generation(
102
+ preloaded_dir, True,
103
+ 0, 0,
104
+ is_test=True)
105
+ else:
106
+ self.cache_generation(
107
+ preloaded_dir, self.args.disable_filtering,
108
+ self.args.clean_first_seconds, self.args.clean_final_seconds,
109
+ is_test=False)
110
+ logger.info(f"BEATX during time is {self.beatx_during_time}s !")
111
+
112
+
113
+ def __len__(self):
114
+ return self.n_samples
115
+
116
+
117
+ def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
118
+ self.n_out_samples = 0
119
+ # create db for samples
120
+ if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir)
121
+ dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G
122
+ n_filtered_out = defaultdict(int)
123
+
124
+ for index, file_name in self.selected_file.iterrows():
125
+ f_name = file_name["id"]
126
+ ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
127
+ pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext
128
+ pose_each_file = []
129
+ trans_each_file = []
130
+ trans_v_each_file = []
131
+ shape_each_file = []
132
+ audio_each_file = []
133
+ facial_each_file = []
134
+ word_each_file = []
135
+ emo_each_file = []
136
+ sem_each_file = []
137
+ vid_each_file = []
138
+ id_pose = f_name #1_wayne_0_1_1
139
+
140
+ logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue"))
141
+ if "smplx" in self.args.pose_rep:
142
+ pose_data = np.load(pose_file, allow_pickle=True)
143
+ assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30'
144
+ stride = int(30/self.args.pose_fps)
145
+ pose_each_file = pose_data["poses"][::stride] * self.joint_mask
146
+ pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)]
147
+
148
+ self.beatx_during_time += pose_each_file.shape[0]/30
149
+ trans_each_file = pose_data["trans"][::stride]
150
+ trans_each_file[:,0] = trans_each_file[:,0] - trans_each_file[0,0]
151
+ trans_each_file[:,2] = trans_each_file[:,2] - trans_each_file[0,2]
152
+ trans_v_each_file = np.zeros_like(trans_each_file)
153
+ trans_v_each_file[1:,0] = trans_each_file[1:,0] - trans_each_file[:-1,0]
154
+ trans_v_each_file[0,0] = trans_v_each_file[1,0]
155
+ trans_v_each_file[1:,2] = trans_each_file[1:,2] - trans_each_file[:-1,2]
156
+ trans_v_each_file[0,2] = trans_v_each_file[1,2]
157
+ trans_v_each_file[:,1] = trans_each_file[:,1]
158
+
159
+
160
+ shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0)
161
+ if self.args.facial_rep is not None:
162
+ logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
163
+ facial_each_file = pose_data["expressions"][::stride]
164
+ if self.args.facial_norm:
165
+ facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial
166
+
167
+ if self.args.id_rep is not None:
168
+ vid_each_file = np.repeat(np.array(int(f_name.split("_")[0])-1).reshape(1, 1), pose_each_file.shape[0], axis=0)
169
+
170
+ filtered_result = self._sample_from_clip(
171
+ dst_lmdb_env,
172
+ pose_each_file, trans_each_file,trans_v_each_file, shape_each_file, facial_each_file,
173
+ vid_each_file,
174
+ disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
175
+ )
176
+ for type in filtered_result.keys():
177
+ n_filtered_out[type] += filtered_result[type]
178
+
179
+ with dst_lmdb_env.begin() as txn:
180
+ logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan"))
181
+ n_total_filtered = 0
182
+ for type, n_filtered in n_filtered_out.items():
183
+ logger.info("{}: {}".format(type, n_filtered))
184
+ n_total_filtered += n_filtered
185
+ logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format(
186
+ n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan"))
187
+ dst_lmdb_env.sync()
188
+ dst_lmdb_env.close()
189
+
190
+ def _sample_from_clip(
191
+ self, dst_lmdb_env, pose_each_file, trans_each_file, trans_v_each_file, shape_each_file, facial_each_file,
192
+ vid_each_file,
193
+ disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
194
+ ):
195
+ """
196
+ for data cleaning, we ignore the data for the first and final n seconds
197
+ for test, we return all data
198
+ """
199
+
200
+ round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s
201
+ #print(round_seconds_skeleton)
202
+
203
+ clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s
204
+ clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000]
205
+ clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15]
206
+
207
+
208
+ for ratio in self.args.multi_length_training:
209
+ if is_test:# stride = length for test
210
+ cut_length = clip_e_f_pose - clip_s_f_pose
211
+ self.args.stride = cut_length
212
+ self.max_length = cut_length
213
+ else:
214
+ self.args.stride = int(ratio*self.ori_stride)
215
+ cut_length = int(self.ori_length*ratio)
216
+
217
+ num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1
218
+ logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}")
219
+ logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}")
220
+
221
+
222
+ n_filtered_out = defaultdict(int)
223
+ sample_pose_list = []
224
+ sample_face_list = []
225
+ sample_shape_list = []
226
+ sample_vid_list = []
227
+ sample_trans_list = []
228
+ sample_trans_v_list = []
229
+
230
+ for i in range(num_subdivision): # cut into clips of cut_length frames (~2 s each)
231
+ start_idx = clip_s_f_pose + i * self.args.stride
232
+ fin_idx = start_idx + cut_length
233
+ sample_pose = pose_each_file[start_idx:fin_idx]
234
+ sample_trans = trans_each_file[start_idx:fin_idx]
235
+ sample_trans_v = trans_v_each_file[start_idx:fin_idx]
236
+ sample_shape = shape_each_file[start_idx:fin_idx]
237
+ sample_face = facial_each_file[start_idx:fin_idx]
238
+ # print(sample_pose.shape)
239
+ sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1])
240
+
241
+ if sample_pose is not None:
242
+ sample_pose_list.append(sample_pose)
243
+
244
+ sample_shape_list.append(sample_shape)
245
+
246
+ sample_vid_list.append(sample_vid)
247
+ sample_face_list.append(sample_face)
248
+
249
+
250
+ sample_trans_list.append(sample_trans)
251
+ sample_trans_v_list.append(sample_trans_v)
252
+
253
+ if len(sample_pose_list) > 0:
254
+ with dst_lmdb_env.begin(write=True) as txn:
255
+ for pose, shape, face, vid, trans,trans_v in zip(
256
+ sample_pose_list,
257
+ sample_shape_list,
258
+ sample_face_list,
259
+ sample_vid_list,
260
+ sample_trans_list,
261
+ sample_trans_v_list,
262
+ ):
263
+ k = "{:005}".format(self.n_out_samples).encode("ascii")
264
+ v = [pose , shape, face, vid, trans,trans_v]
265
+ v = pickle.dumps(v,5)
266
+ txn.put(k, v)
267
+ self.n_out_samples += 1
268
+ return n_filtered_out
269
+
270
+ def __getitem__(self, idx):
271
+ with self.lmdb_env.begin(write=False) as txn:
272
+ key = "{:005}".format(idx).encode("ascii")
273
+ sample = txn.get(key)
274
+ sample = pickle.loads(sample)
275
+ tar_pose, in_shape, tar_face, vid, trans,trans_v = sample
276
+ tar_pose = torch.from_numpy(tar_pose).float()
277
+ tar_face = torch.from_numpy(tar_face).float()
278
+ tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(-1, 55, 3))
279
+ tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(-1, 55*6)
280
+
281
+ if self.norm:
282
+ tar_pose = (tar_pose - self.mean) / self.std
283
+ trans_v = (trans_v-self.trans_mean)/self.trans_std
284
+
285
+ if self.loader_type == "test":
286
+ tar_pose = tar_pose.float()
287
+ trans = torch.from_numpy(trans).float()
288
+ trans_v = torch.from_numpy(trans_v).float()
289
+ vid = torch.from_numpy(vid).float()
290
+ in_shape = torch.from_numpy(in_shape).float()
291
+ tar_pose = torch.cat([tar_pose, trans_v], dim=1)
292
+ tar_pose = torch.cat([tar_pose, tar_face], dim=1)
293
+ else:
294
+ in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
295
+ trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float()
296
+ trans_v = torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float()
297
+ vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float()
298
+ tar_pose = tar_pose.reshape((tar_pose.shape[0], -1)).float()
299
+ tar_pose = torch.cat([tar_pose, trans_v], dim=1)
300
+ tar_pose = torch.cat([tar_pose, tar_face], dim=1)
301
+ return tar_pose
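A rough usage sketch for this dataset follows. It is not part of the upload: the argument values are placeholders that have to match the local BEAT2 layout and the project's YAML configs, the `mean_std/*.npy` files must exist, and each returned frame concatenates 6D rotations, translation velocity, and facial expressions.

from types import SimpleNamespace
from torch.utils.data import DataLoader
from dataloaders.mix_sep import CustomDataset

# Placeholder arguments; the real values come from the YAML configs.
args = SimpleNamespace(
    data_path="./datasets/BEAT2/beat_english_v2.0.0/", root_path="./",
    cache_path="datasets/beat_cache/", new_cache=False,
    pose_rep="smplxflame_30", ori_joints="beat_smplx_joints", tar_joints="beat_smplx_full",
    pose_fps=30, pose_length=128, stride=20, test_length=128,
    audio_sr=16000, audio_fps=16000,
    facial_rep="smplxflame_30", facial_norm=False, id_rep="onehot",
    training_speakers=[2], additional_data=False, multi_length_training=[1.0],
    disable_filtering=False, clean_first_seconds=0, clean_final_seconds=0,
)

train_set = CustomDataset(args, "train")                       # builds or reuses the LMDB cache
loader = DataLoader(train_set, batch_size=32, shuffle=True, drop_last=True)
for tar_pose in loader:                                        # (B, T, rot6d + trans_v + expressions)
    print(tar_pose.shape)
    break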
dataloaders/pymo/Quaternions.py ADDED
@@ -0,0 +1,468 @@
1
+ import numpy as np
2
+
3
+ class Quaternions:
4
+ """
5
+ Quaternions is a wrapper around a numpy ndarray
6
+ that allows it to act as if it were an ndarray of
7
+ a quaternion data type.
8
+
9
+ Therefore addition, subtraction, multiplication,
10
+ division, negation, absolute, are all defined
11
+ in terms of quaternion operations such as quaternion
12
+ multiplication.
13
+
14
+ This allows for much neater code and many routines
15
+ which conceptually do the same thing to be written
16
+ in the same way for point data and for rotation data.
17
+
18
+ The Quaternions class has been designed such that it
19
+ should support broadcasting and slicing in all of the
20
+ usual ways.
21
+ """
22
+
23
+ def __init__(self, qs):
24
+ if isinstance(qs, np.ndarray):
25
+
26
+ if len(qs.shape) == 1: qs = np.array([qs])
27
+ self.qs = qs
28
+ return
29
+
30
+ if isinstance(qs, Quaternions):
31
+ self.qs = qs.qs
32
+ return
33
+
34
+ raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs))
35
+
36
+ def __str__(self): return "Quaternions("+ str(self.qs) + ")"
37
+ def __repr__(self): return "Quaternions("+ repr(self.qs) + ")"
38
+
39
+ """ Helper Methods for Broadcasting and Data extraction """
40
+
41
+ @classmethod
42
+ def _broadcast(cls, sqs, oqs, scalar=False):
43
+
44
+ if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1])
45
+
46
+ ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1])
47
+ os = np.array(oqs.shape)
48
+
49
+ if len(ss) != len(os):
50
+ raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
51
+
52
+ if np.all(ss == os): return sqs, oqs
53
+
54
+ if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))):
55
+ raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
56
+
57
+ sqsn, oqsn = sqs.copy(), oqs.copy()
58
+
59
+ for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a)
60
+ for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a)
61
+
62
+ return sqsn, oqsn
63
+
64
+ """ Adding Quaterions is just Defined as Multiplication """
65
+
66
+ def __add__(self, other): return self * other
67
+ def __sub__(self, other): return self / other
68
+
69
+ """ Quaterion Multiplication """
70
+
71
+ def __mul__(self, other):
72
+ """
73
+ Quaternion multiplication has three main methods.
74
+
75
+ When multiplying a Quaternions array by Quaternions
76
+ normal quaternion multiplication is performed.
77
+
78
+ When multiplying a Quaternions array by a vector
79
+ array of the same shape, where the last axis is 3,
80
+ it is assumed to be a Quaternion by 3D-Vector
81
+ multiplication and the 3D-Vectors are rotated
82
+ in space by the Quaternions.
83
+
84
+ When multiplying a Quaternions array by a scalar
85
+ or vector of different shape it is assumed to be
86
+ a Quaternions by Scalars multiplication and the
87
+ Quaternions are scaled using Slerp and the identity
88
+ quaternions.
89
+ """
90
+
91
+ """ If Quaternions type do Quaternions * Quaternions """
92
+ if isinstance(other, Quaternions):
93
+
94
+ sqs, oqs = Quaternions._broadcast(self.qs, other.qs)
95
+
96
+ q0 = sqs[...,0]; q1 = sqs[...,1];
97
+ q2 = sqs[...,2]; q3 = sqs[...,3];
98
+ r0 = oqs[...,0]; r1 = oqs[...,1];
99
+ r2 = oqs[...,2]; r3 = oqs[...,3];
100
+
101
+ qs = np.empty(sqs.shape)
102
+ qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3
103
+ qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2
104
+ qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1
105
+ qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0
106
+
107
+ return Quaternions(qs)
108
+
109
+ """ If array type do Quaternions * Vectors """
110
+ if isinstance(other, np.ndarray) and other.shape[-1] == 3:
111
+ vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1))
112
+ return (self * (vs * -self)).imaginaries
113
+
114
+ """ If float do Quaternions * Scalars """
115
+ if isinstance(other, np.ndarray) or isinstance(other, float):
116
+ return Quaternions.slerp(Quaternions.id_like(self), self, other)
117
+
118
+ raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other)))
119
+
120
+ def __div__(self, other):
121
+ """
122
+ When a Quaternion type is supplied, division is defined
123
+ as multiplication by the inverse of that Quaternion.
124
+
125
+ When a scalar or vector is supplied it is defined
126
+ as multiplication of one over the supplied value.
127
+ Essentially a scaling.
128
+ """
129
+
130
+ if isinstance(other, Quaternions): return self * (-other)
131
+ if isinstance(other, np.ndarray): return self * (1.0 / other)
132
+ if isinstance(other, float): return self * (1.0 / other)
133
+ raise TypeError('Cannot divide/subtract Quaternions with type %s' % str(type(other)))
+
+ __truediv__ = __div__ # Python 3 maps the / operator to __truediv__
134
+
135
+ def __eq__(self, other): return self.qs == other.qs
136
+ def __ne__(self, other): return self.qs != other.qs
137
+
138
+ def __neg__(self):
139
+ """ Invert Quaternions """
140
+ return Quaternions(self.qs * np.array([[1, -1, -1, -1]]))
141
+
142
+ def __abs__(self):
143
+ """ Unify Quaternions To Single Pole """
144
+ qabs = self.normalized().copy()
145
+ top = np.sum(( qabs.qs) * np.array([1,0,0,0]), axis=-1)
146
+ bot = np.sum((-qabs.qs) * np.array([1,0,0,0]), axis=-1)
147
+ qabs.qs[top < bot] = -qabs.qs[top < bot]
148
+ return qabs
149
+
150
+ def __iter__(self): return iter(self.qs)
151
+ def __len__(self): return len(self.qs)
152
+
153
+ def __getitem__(self, k): return Quaternions(self.qs[k])
154
+ def __setitem__(self, k, v): self.qs[k] = v.qs
155
+
156
+ @property
157
+ def lengths(self):
158
+ return np.sum(self.qs**2.0, axis=-1)**0.5
159
+
160
+ @property
161
+ def reals(self):
162
+ return self.qs[...,0]
163
+
164
+ @property
165
+ def imaginaries(self):
166
+ return self.qs[...,1:4]
167
+
168
+ @property
169
+ def shape(self): return self.qs.shape[:-1]
170
+
171
+ def repeat(self, n, **kwargs):
172
+ return Quaternions(self.qs.repeat(n, **kwargs))
173
+
174
+ def normalized(self):
175
+ return Quaternions(self.qs / self.lengths[...,np.newaxis])
176
+
177
+ def log(self):
178
+ norm = abs(self.normalized())
179
+ imgs = norm.imaginaries
180
+ lens = np.sqrt(np.sum(imgs**2, axis=-1))
181
+ lens = np.arctan2(lens, norm.reals) / (lens + 1e-10)
182
+ return imgs * lens[...,np.newaxis]
183
+
184
+ def constrained(self, axis):
185
+
186
+ rl = self.reals
187
+ im = np.sum(axis * self.imaginaries, axis=-1)
188
+
189
+ t1 = -2 * np.arctan2(rl, im) + np.pi
190
+ t2 = -2 * np.arctan2(rl, im) - np.pi
191
+
192
+ top = Quaternions.exp(axis[np.newaxis] * (t1[:,np.newaxis] / 2.0))
193
+ bot = Quaternions.exp(axis[np.newaxis] * (t2[:,np.newaxis] / 2.0))
194
+ img = self.dot(top) > self.dot(bot)
195
+
196
+ ret = top.copy()
197
+ ret[ img] = top[ img]
198
+ ret[~img] = bot[~img]
199
+ return ret
200
+
201
+ def constrained_x(self): return self.constrained(np.array([1,0,0]))
202
+ def constrained_y(self): return self.constrained(np.array([0,1,0]))
203
+ def constrained_z(self): return self.constrained(np.array([0,0,1]))
204
+
205
+ def dot(self, q): return np.sum(self.qs * q.qs, axis=-1)
206
+
207
+ def copy(self): return Quaternions(np.copy(self.qs))
208
+
209
+ def reshape(self, s):
210
+ self.qs = self.qs.reshape(s)
211
+ return self
212
+
213
+ def interpolate(self, ws):
214
+ return Quaternions.exp(np.average(abs(self).log(), axis=0, weights=ws))
215
+
216
+ def euler(self, order='xyz'):
217
+
218
+ q = self.normalized().qs
219
+ q0 = q[...,0]
220
+ q1 = q[...,1]
221
+ q2 = q[...,2]
222
+ q3 = q[...,3]
223
+ es = np.zeros(self.shape + (3,))
224
+
225
+ if order == 'xyz':
226
+ es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
227
+ es[...,1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1,1))
228
+ es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
229
+ elif order == 'yzx':
230
+ es[...,0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0)
231
+ es[...,1] = np.arctan2(2 * (q2 * q0 - q1 * q3), q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0)
232
+ es[...,2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1,1))
233
+ else:
234
+ raise NotImplementedError('Cannot convert from ordering %s' % order)
235
+
236
+ """
237
+
238
+ # These conversion don't appear to work correctly for Maya.
239
+ # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/
240
+
241
+ if order == 'xyz':
242
+ es[...,0] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
243
+ es[...,1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1))
244
+ es[...,2] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
245
+ elif order == 'yzx':
246
+ es[...,0] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
247
+ es[...,1] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1))
248
+ es[...,2] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
249
+ elif order == 'zxy':
250
+ es[...,0] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
251
+ es[...,1] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1))
252
+ es[...,2] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
253
+ elif order == 'xzy':
254
+ es[...,0] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
255
+ es[...,1] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1))
256
+ es[...,2] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
257
+ elif order == 'yxz':
258
+ es[...,0] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
259
+ es[...,1] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1))
260
+ es[...,2] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
261
+ elif order == 'zyx':
262
+ es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
263
+ es[...,1] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1))
264
+ es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
265
+ else:
266
+ raise KeyError('Unknown ordering %s' % order)
267
+
268
+ """
269
+
270
+ # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp
271
+ # Use this class and convert from matrix
272
+
273
+ return es
274
+
275
+
276
+ def average(self):
277
+
278
+ if len(self.shape) == 1:
279
+
280
+ import numpy.core.umath_tests as ut
281
+ system = ut.matrix_multiply(self.qs[:,:,np.newaxis], self.qs[:,np.newaxis,:]).sum(axis=0)
282
+ w, v = np.linalg.eigh(system)
283
+ qiT_dot_qref = (self.qs[:,:,np.newaxis] * v[np.newaxis,:,:]).sum(axis=1)
284
+ return Quaternions(v[:,np.argmin((1.-qiT_dot_qref**2).sum(axis=0))])
285
+
286
+ else:
287
+
288
+ raise NotImplementedError('Cannot average multi-dimensionsal Quaternions')
289
+
290
+ def angle_axis(self):
291
+
292
+ norm = self.normalized()
293
+ s = np.sqrt(1 - (norm.reals**2.0))
294
+ s[s == 0] = 0.001
295
+
296
+ angles = 2.0 * np.arccos(norm.reals)
297
+ axis = norm.imaginaries / s[...,np.newaxis]
298
+
299
+ return angles, axis
300
+
301
+
302
+ def transforms(self):
303
+
304
+ qw = self.qs[...,0]
305
+ qx = self.qs[...,1]
306
+ qy = self.qs[...,2]
307
+ qz = self.qs[...,3]
308
+
309
+ x2 = qx + qx; y2 = qy + qy; z2 = qz + qz;
310
+ xx = qx * x2; yy = qy * y2; wx = qw * x2;
311
+ xy = qx * y2; yz = qy * z2; wy = qw * y2;
312
+ xz = qx * z2; zz = qz * z2; wz = qw * z2;
313
+
314
+ m = np.empty(self.shape + (3,3))
315
+ m[...,0,0] = 1.0 - (yy + zz)
316
+ m[...,0,1] = xy - wz
317
+ m[...,0,2] = xz + wy
318
+ m[...,1,0] = xy + wz
319
+ m[...,1,1] = 1.0 - (xx + zz)
320
+ m[...,1,2] = yz - wx
321
+ m[...,2,0] = xz - wy
322
+ m[...,2,1] = yz + wx
323
+ m[...,2,2] = 1.0 - (xx + yy)
324
+
325
+ return m
326
+
327
+ def ravel(self):
328
+ return self.qs.ravel()
329
+
330
+ @classmethod
331
+ def id(cls, n):
332
+
333
+ if isinstance(n, tuple):
334
+ qs = np.zeros(n + (4,))
335
+ qs[...,0] = 1.0
336
+ return Quaternions(qs)
337
+
338
+ if isinstance(n, int): # Python 3 has no separate 'long' type
339
+ qs = np.zeros((n,4))
340
+ qs[:,0] = 1.0
341
+ return Quaternions(qs)
342
+
343
+ raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n)))
344
+
345
+ @classmethod
346
+ def id_like(cls, a):
347
+ qs = np.zeros(a.shape + (4,))
348
+ qs[...,0] = 1.0
349
+ return Quaternions(qs)
350
+
351
+ @classmethod
352
+ def exp(cls, ws):
353
+
354
+ ts = np.sum(ws**2.0, axis=-1)**0.5
355
+ ts[ts == 0] = 0.001
356
+ ls = np.sin(ts) / ts
357
+
358
+ qs = np.empty(ws.shape[:-1] + (4,))
359
+ qs[...,0] = np.cos(ts)
360
+ qs[...,1] = ws[...,0] * ls
361
+ qs[...,2] = ws[...,1] * ls
362
+ qs[...,3] = ws[...,2] * ls
363
+
364
+ return Quaternions(qs).normalized()
365
+
366
+ @classmethod
367
+ def slerp(cls, q0s, q1s, a):
368
+
369
+ fst, snd = cls._broadcast(q0s.qs, q1s.qs)
370
+ fst, a = cls._broadcast(fst, a, scalar=True)
371
+ snd, a = cls._broadcast(snd, a, scalar=True)
372
+
373
+ len = np.sum(fst * snd, axis=-1)
374
+
375
+ neg = len < 0.0
376
+ len[neg] = -len[neg]
377
+ snd[neg] = -snd[neg]
378
+
379
+ amount0 = np.zeros(a.shape)
380
+ amount1 = np.zeros(a.shape)
381
+
382
+ linear = (1.0 - len) < 0.01
383
+ omegas = np.arccos(len[~linear])
384
+ sinoms = np.sin(omegas)
385
+
386
+ amount0[ linear] = 1.0 - a[linear]
387
+ amount1[ linear] = a[linear]
388
+ amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms
389
+ amount1[~linear] = np.sin( a[~linear] * omegas) / sinoms
390
+
391
+ return Quaternions(
392
+ amount0[...,np.newaxis] * fst +
393
+ amount1[...,np.newaxis] * snd)
394
+
395
+ @classmethod
396
+ def between(cls, v0s, v1s):
397
+ a = np.cross(v0s, v1s)
398
+ w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1)
399
+ return Quaternions(np.concatenate([w[...,np.newaxis], a], axis=-1)).normalized()
400
+
401
+ @classmethod
402
+ def from_angle_axis(cls, angles, axis):
403
+ axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[...,np.newaxis]
404
+ sines = np.sin(angles / 2.0)[...,np.newaxis]
405
+ cosines = np.cos(angles / 2.0)[...,np.newaxis]
406
+ return Quaternions(np.concatenate([cosines, axis * sines], axis=-1))
407
+
408
+ @classmethod
409
+ def from_euler(cls, es, order='xyz', world=False):
410
+
411
+ axis = {
412
+ 'x' : np.array([1,0,0]),
413
+ 'y' : np.array([0,1,0]),
414
+ 'z' : np.array([0,0,1]),
415
+ }
416
+
417
+ q0s = Quaternions.from_angle_axis(es[...,0], axis[order[0]])
418
+ q1s = Quaternions.from_angle_axis(es[...,1], axis[order[1]])
419
+ q2s = Quaternions.from_angle_axis(es[...,2], axis[order[2]])
420
+
421
+ return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s))
422
+
423
+ @classmethod
424
+ def from_transforms(cls, ts):
425
+
426
+ d0, d1, d2 = ts[...,0,0], ts[...,1,1], ts[...,2,2]
427
+
428
+ q0 = ( d0 + d1 + d2 + 1.0) / 4.0
429
+ q1 = ( d0 - d1 - d2 + 1.0) / 4.0
430
+ q2 = (-d0 + d1 - d2 + 1.0) / 4.0
431
+ q3 = (-d0 - d1 + d2 + 1.0) / 4.0
432
+
433
+ q0 = np.sqrt(q0.clip(0,None))
434
+ q1 = np.sqrt(q1.clip(0,None))
435
+ q2 = np.sqrt(q2.clip(0,None))
436
+ q3 = np.sqrt(q3.clip(0,None))
437
+
438
+ c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3)
439
+ c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3)
440
+ c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3)
441
+ c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2)
442
+
443
+ q1[c0] *= np.sign(ts[c0,2,1] - ts[c0,1,2])
444
+ q2[c0] *= np.sign(ts[c0,0,2] - ts[c0,2,0])
445
+ q3[c0] *= np.sign(ts[c0,1,0] - ts[c0,0,1])
446
+
447
+ q0[c1] *= np.sign(ts[c1,2,1] - ts[c1,1,2])
448
+ q2[c1] *= np.sign(ts[c1,1,0] + ts[c1,0,1])
449
+ q3[c1] *= np.sign(ts[c1,0,2] + ts[c1,2,0])
450
+
451
+ q0[c2] *= np.sign(ts[c2,0,2] - ts[c2,2,0])
452
+ q1[c2] *= np.sign(ts[c2,1,0] + ts[c2,0,1])
453
+ q3[c2] *= np.sign(ts[c2,2,1] + ts[c2,1,2])
454
+
455
+ q0[c3] *= np.sign(ts[c3,1,0] - ts[c3,0,1])
456
+ q1[c3] *= np.sign(ts[c3,2,0] + ts[c3,0,2])
457
+ q2[c3] *= np.sign(ts[c3,2,1] + ts[c3,1,2])
458
+
459
+ qs = np.empty(ts.shape[:-2] + (4,))
460
+ qs[...,0] = q0
461
+ qs[...,1] = q1
462
+ qs[...,2] = q2
463
+ qs[...,3] = q3
464
+
465
+ return cls(qs)
466
+
467
+
468
+
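A short sketch of the three multiplication modes described in `__mul__` (quaternion composition, vector rotation, and slerp-based scaling); it assumes the repository root is on `PYTHONPATH`.

import numpy as np
from dataloaders.pymo.Quaternions import Quaternions

eulers = np.radians([[90.0,  0.0, 0.0],
                     [ 0.0, 45.0, 0.0]])
q = Quaternions.from_euler(eulers, order='xyz', world=False)

print((q * q).euler())                       # Quaternions * Quaternions: composed rotation
print(q * np.array([[0.0, 1.0, 0.0],
                    [0.0, 1.0, 0.0]]))       # Quaternions * 3-vectors: rotate the vectors
print((q * 0.5).euler())                     # Quaternions * scalar: slerp towards identity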
dataloaders/pymo/__init__.py ADDED
File without changes
dataloaders/pymo/__pycache__/Quaternions.cpython-312.pyc ADDED
Binary file (28.3 kB). View file
 
dataloaders/pymo/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (173 Bytes). View file
 
dataloaders/pymo/__pycache__/data.cpython-312.pyc ADDED
Binary file (3.34 kB). View file
 
dataloaders/pymo/__pycache__/parsers.cpython-312.pyc ADDED
Binary file (13.1 kB). View file
 
dataloaders/pymo/__pycache__/preprocessing.cpython-312.pyc ADDED
Binary file (33.8 kB). View file
 
dataloaders/pymo/__pycache__/rotation_tools.cpython-312.pyc ADDED
Binary file (8.27 kB). View file
 
dataloaders/pymo/__pycache__/viz_tools.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
dataloaders/pymo/data.py ADDED
@@ -0,0 +1,53 @@
1
+ import numpy as np
2
+
3
+ class Joint():
4
+ def __init__(self, name, parent=None, children=None):
5
+ self.name = name
6
+ self.parent = parent
7
+ self.children = children
8
+
9
+ class MocapData():
10
+ def __init__(self):
11
+ self.skeleton = {}
12
+ self.values = None
13
+ self.channel_names = []
14
+ self.framerate = 0.0
15
+ self.root_name = ''
16
+
17
+ def traverse(self, j=None):
18
+ stack = [self.root_name]
19
+ while stack:
20
+ joint = stack.pop()
21
+ yield joint
22
+ for c in self.skeleton[joint]['children']:
23
+ stack.append(c)
24
+
25
+ def clone(self):
26
+ import copy
27
+ new_data = MocapData()
28
+ new_data.skeleton = copy.copy(self.skeleton)
29
+ new_data.values = copy.copy(self.values)
30
+ new_data.channel_names = copy.copy(self.channel_names)
31
+ new_data.root_name = copy.copy(self.root_name)
32
+ new_data.framerate = copy.copy(self.framerate)
33
+ return new_data
34
+
35
+ def get_all_channels(self):
36
+ '''Returns all of the channels parsed from the file as a 2D numpy array'''
37
+
38
+ frames = [f[1] for f in self.values]
39
+ return np.asarray([[channel[2] for channel in frame] for frame in frames])
40
+
41
+ def get_skeleton_tree(self):
42
+ tree = []
43
+ root_key = [j for j in self.skeleton if self.skeleton[j]['parent']==None][0]
44
+
45
+ root_joint = Joint(root_key)
46
+
47
+ def get_empty_channels(self):
48
+ #TODO
49
+ pass
50
+
51
+ def get_constant_channels(self):
52
+ #TODO
53
+ pass
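For reference, a tiny sketch of how `MocapData.traverse()` walks the skeleton dictionary produced by the BVH parser; the two-joint skeleton is invented for illustration.

from dataloaders.pymo.data import MocapData

m = MocapData()
m.root_name = 'Hips'
m.skeleton = {
    'Hips':  {'parent': None,   'channels': [], 'offsets': [0, 0, 0],  'order': 'ZXY', 'children': ['Spine']},
    'Spine': {'parent': 'Hips', 'channels': [], 'offsets': [0, 10, 0], 'order': 'ZXY', 'children': []},
}
print(list(m.traverse()))   # ['Hips', 'Spine'] -- depth-first walk from the root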
dataloaders/pymo/features.py ADDED
@@ -0,0 +1,43 @@
1
+ '''
2
+ A set of mocap feature extraction functions
3
+
4
+ Created by Omid Alemi | Nov 17 2017
5
+
6
+ '''
7
+ import numpy as np
8
+ import pandas as pd
9
+ import peakutils
10
+ import matplotlib.pyplot as plt
11
+
12
+ def get_foot_contact_idxs(signal, t=0.02, min_dist=120):
13
+ up_idxs = peakutils.indexes(signal, thres=t/max(signal), min_dist=min_dist)
14
+ down_idxs = peakutils.indexes(-signal, thres=t/min(signal), min_dist=min_dist)
15
+
16
+ return [up_idxs, down_idxs]
17
+
18
+
19
+ def create_foot_contact_signal(mocap_track, col_name, start=1, t=0.02, min_dist=120):
20
+ signal = mocap_track.values[col_name].values
21
+ idxs = get_foot_contact_idxs(signal, t, min_dist)
22
+
23
+ step_signal = []
24
+
25
+ c = start
26
+ for f in range(len(signal)):
27
+ if f in idxs[1]:
28
+ c = 0
29
+ elif f in idxs[0]:
30
+ c = 1
31
+
32
+ step_signal.append(c)
33
+
34
+ return step_signal
35
+
36
+ def plot_foot_up_down(mocap_track, col_name, t=0.02, min_dist=120):
37
+
38
+ signal = mocap_track.values[col_name].values
39
+ idxs = get_foot_contact_idxs(signal, t, min_dist)
40
+
41
+ plt.plot(mocap_track.values.index, signal)
42
+ plt.plot(mocap_track.values.index[idxs[0]], signal[idxs[0]], 'ro')
43
+ plt.plot(mocap_track.values.index[idxs[1]], signal[idxs[1]], 'go')
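A hedged sketch of `get_foot_contact_idxs` on a synthetic heel-height curve; it only illustrates how the up/down indices are produced, not real BEAT data, and assumes `peakutils` and `matplotlib` are installed since they are imported at module level.

import numpy as np
from dataloaders.pymo.features import get_foot_contact_idxs

frames = np.arange(1200)
heel_y = np.sin(2 * np.pi * frames / 300.0)     # fake periodic foot lift, period 300 frames

up_idxs, down_idxs = get_foot_contact_idxs(heel_y, t=0.02, min_dist=120)
print(up_idxs)     # frames where the foot is highest
print(down_idxs)   # frames where the foot is lowest (treated as contact)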
dataloaders/pymo/parsers.py ADDED
@@ -0,0 +1,274 @@
1
+ '''
2
+ BVH Parser Class
3
+
4
+ By Omid Alemi
5
+ Created: June 12, 2017
6
+
7
+ Based on: https://gist.github.com/johnfredcee/2007503
8
+
9
+ '''
10
+ import re
11
+ from unicodedata import name
12
+ import numpy as np
13
+ from .data import Joint, MocapData
14
+
15
+ class BVHScanner():
16
+ '''
17
+ A wrapper class for re.Scanner
18
+ '''
19
+ def __init__(self):
20
+
21
+ def identifier(scanner, token):
22
+ return 'IDENT', token
23
+
24
+ def operator(scanner, token):
25
+ return 'OPERATOR', token
26
+
27
+ def digit(scanner, token):
28
+ return 'DIGIT', token
29
+
30
+ def open_brace(scanner, token):
31
+ return 'OPEN_BRACE', token
32
+
33
+ def close_brace(scanner, token):
34
+ return 'CLOSE_BRACE', token
35
+
36
+ self.scanner = re.Scanner([
37
+ (r'[a-zA-Z_]\w*', identifier),
38
+ #(r'-*[0-9]+(\.[0-9]+)?', digit), # won't work for .34
39
+ #(r'[-+]?[0-9]*\.?[0-9]+', digit), # won't work for 4.56e-2
40
+ #(r'[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit),
41
+ (r'-*[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit),
42
+ (r'}', close_brace),
43
+ (r'}', close_brace),
44
+ (r'{', open_brace),
45
+ (r':', None),
46
+ (r'\s+', None)
47
+ ])
48
+
49
+ def scan(self, stuff):
50
+ return self.scanner.scan(stuff)
51
+
52
+
53
+
54
+ class BVHParser():
55
+ '''
56
+ A class to parse a BVH file.
57
+
58
+ Extracts the skeleton and channel values
59
+ '''
60
+ def __init__(self, filename=None):
61
+ self.reset()
62
+
63
+ def reset(self):
64
+ self._skeleton = {}
65
+ self.bone_context = []
66
+ self._motion_channels = []
67
+ self._motions = []
68
+ self.current_token = 0
69
+ self.framerate = 0.0
70
+ self.root_name = ''
71
+
72
+ self.scanner = BVHScanner()
73
+
74
+ self.data = MocapData()
75
+
76
+
77
+ def parse(self, filename, start=0, stop=-1):
78
+ self.reset()
79
+ self.correct_row_num = 0
80
+ with open(filename, 'r') as f:
81
+ for line in f.readlines():
82
+ self.correct_row_num += 1
83
+
84
+ with open(filename, 'r') as bvh_file:
85
+ raw_contents = bvh_file.read()
86
+ tokens, remainder = self.scanner.scan(raw_contents)
87
+
88
+ self._parse_hierarchy(tokens)
89
+ self.current_token = self.current_token + 1
90
+ self._parse_motion(tokens, start, stop)
91
+
92
+ self.data.skeleton = self._skeleton
93
+ self.data.channel_names = self._motion_channels
94
+ self.data.values = self._to_DataFrame()
95
+ self.data.root_name = self.root_name
96
+ self.data.framerate = self.framerate
97
+
98
+ return self.data
99
+
100
+ def _to_DataFrame(self):
101
+ '''Returns all of the channels parsed from the file as a pandas DataFrame'''
102
+
103
+ import pandas as pd
104
+ time_index = pd.to_timedelta([f[0] for f in self._motions], unit='s')
105
+ frames = [f[1] for f in self._motions]
106
+ channels = np.asarray([[channel[2] for channel in frame] for frame in frames])
107
+ column_names = ['%s_%s'%(c[0], c[1]) for c in self._motion_channels]
108
+
109
+ return pd.DataFrame(data=channels, index=time_index, columns=column_names)
110
+
111
+
112
+ def _new_bone(self, parent, name):
113
+ bone = {'parent': parent, 'channels': [], 'offsets': [], 'order': '','children': []}
114
+ return bone
115
+
116
+ def _push_bone_context(self,name):
117
+ self.bone_context.append(name)
118
+
119
+ def _get_bone_context(self):
120
+ return self.bone_context[len(self.bone_context)-1]
121
+
122
+ def _pop_bone_context(self):
123
+ self.bone_context = self.bone_context[:-1]
124
+ return self.bone_context[len(self.bone_context)-1]
125
+
126
+ def _read_offset(self, bvh, token_index):
127
+ if bvh[token_index] != ('IDENT', 'OFFSET'):
128
+ return None, None
129
+ token_index = token_index + 1
130
+ offsets = [0.0] * 3
131
+ for i in range(3):
132
+ offsets[i] = float(bvh[token_index][1])
133
+ token_index = token_index + 1
134
+ return offsets, token_index
135
+
136
+ def _read_channels(self, bvh, token_index):
137
+ if bvh[token_index] != ('IDENT', 'CHANNELS'):
138
+ return None, None
139
+ token_index = token_index + 1
140
+ channel_count = int(bvh[token_index][1])
141
+ token_index = token_index + 1
142
+ channels = [""] * channel_count
143
+ order = ""
144
+ for i in range(channel_count):
145
+ channels[i] = bvh[token_index][1]
146
+ token_index = token_index + 1
147
+ if(channels[i] == "Xrotation" or channels[i]== "Yrotation" or channels[i]== "Zrotation"):
148
+ order += channels[i][0]
149
+ else :
150
+ order = ""
151
+ return channels, token_index, order
152
+
153
+ def _parse_joint(self, bvh, token_index):
154
+ end_site = False
155
+ joint_id = bvh[token_index][1]
156
+ token_index = token_index + 1
157
+ joint_name = bvh[token_index][1]
158
+ token_index = token_index + 1
159
+
160
+ parent_name = self._get_bone_context()
161
+
162
+ if (joint_id == "End"):
163
+ joint_name = parent_name+ '_Nub'
164
+ end_site = True
165
+ joint = self._new_bone(parent_name, joint_name)
166
+ if bvh[token_index][0] != 'OPEN_BRACE':
167
+ print('Was expecting brace, got ', bvh[token_index])
168
+ return None
169
+ token_index = token_index + 1
170
+ offsets, token_index = self._read_offset(bvh, token_index)
171
+ joint['offsets'] = offsets
172
+ if not end_site:
173
+ channels, token_index, order = self._read_channels(bvh, token_index)
174
+ joint['channels'] = channels
175
+ joint['order'] = order
176
+ for channel in channels:
177
+ self._motion_channels.append((joint_name, channel))
178
+
179
+ self._skeleton[joint_name] = joint
180
+ self._skeleton[parent_name]['children'].append(joint_name)
181
+
182
+ while (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'JOINT') or (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'End'):
183
+ self._push_bone_context(joint_name)
184
+ token_index = self._parse_joint(bvh, token_index)
185
+ self._pop_bone_context()
186
+
187
+ if bvh[token_index][0] == 'CLOSE_BRACE':
188
+ return token_index + 1
189
+
190
+ print('Unexpected token ', bvh[token_index])
191
+
192
+ def _parse_hierarchy(self, bvh):
193
+ self.current_token = 0
194
+ if bvh[self.current_token] != ('IDENT', 'HIERARCHY'):
195
+ return None
196
+ self.current_token = self.current_token + 1
197
+ if bvh[self.current_token] != ('IDENT', 'ROOT'):
198
+ return None
199
+ self.current_token = self.current_token + 1
200
+ if bvh[self.current_token][0] != 'IDENT':
201
+ return None
202
+
203
+ root_name = bvh[self.current_token][1]
204
+ root_bone = self._new_bone(None, root_name)
205
+ self.current_token = self.current_token + 2 #skipping open brace
206
+ offsets, self.current_token = self._read_offset(bvh, self.current_token)
207
+ channels, self.current_token, order = self._read_channels(bvh, self.current_token)
208
+ root_bone['offsets'] = offsets
209
+ root_bone['channels'] = channels
210
+ root_bone['order'] = order
211
+ self._skeleton[root_name] = root_bone
212
+ self._push_bone_context(root_name)
213
+
214
+ for channel in channels:
215
+ self._motion_channels.append((root_name, channel))
216
+
217
+ while bvh[self.current_token][1] == 'JOINT':
218
+ self.current_token = self._parse_joint(bvh, self.current_token)
219
+
220
+ self.root_name = root_name
221
+
222
+ def _parse_motion(self, bvh, start, stop):
223
+ if bvh[self.current_token][0] != 'IDENT':
224
+ print('Unexpected text')
225
+ return None
226
+ if bvh[self.current_token][1] != 'MOTION':
227
+ print('No motion section')
228
+ return None
229
+ self.current_token = self.current_token + 1
230
+ if bvh[self.current_token][1] != 'Frames':
231
+ return None
232
+ self.current_token = self.current_token + 1
233
+ frame_count = int(bvh[self.current_token][1])
234
+
235
+ if stop<0 or stop>frame_count:
236
+ stop = min(frame_count, self.correct_row_num-431)
237
+
238
+ assert(start>=0)
239
+ assert(start<stop)
240
+
241
+ self.current_token = self.current_token + 1
242
+ if bvh[self.current_token][1] != 'Frame':
243
+ return None
244
+ self.current_token = self.current_token + 1
245
+ if bvh[self.current_token][1] != 'Time':
246
+ return None
247
+ self.current_token = self.current_token + 1
248
+ frame_rate = float(bvh[self.current_token][1])
249
+
250
+ self.framerate = frame_rate
251
+
252
+ self.current_token = self.current_token + 1
253
+
254
+ frame_time = 0.0
255
+ self._motions = [()] * (stop-start)
256
+ idx=0
257
+ for i in range(stop):
258
+ #print(i)
259
+ channel_values = []
260
+
261
+ for channel in self._motion_channels:
262
+ #print(channel)
263
+ channel_values.append((channel[0], channel[1], float(bvh[self.current_token][1])))
264
+ self.current_token = self.current_token + 1
265
+
266
+ if i>=start:
267
+ self._motions[idx] = (frame_time, channel_values)
268
+ frame_time = frame_time + frame_rate
269
+ idx+=1
270
+
271
+
272
+ if __name__ == "__main__":
273
+ p = BVHParser()
274
+ data = [p.parse("../../../datasets/beat_full/2/2_scott_0_1_1.bvh")]
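The `__main__` block above already shows a bare parse call; the sketch below just illustrates what the returned `MocapData` exposes. The .bvh path is a placeholder.

from dataloaders.pymo.parsers import BVHParser

mocap = BVHParser().parse("path/to/2_scott_0_1_1.bvh")   # placeholder path

print(mocap.root_name, mocap.framerate)      # root joint name and the parsed frame time (s/frame)
print(len(mocap.skeleton))                   # joints with their offsets / channels / children
print(mocap.values.shape)                    # (n_frames, n_channels) pandas DataFrame
print(list(mocap.values.columns[:6]))        # e.g. '<root>_Xposition', '<root>_Yposition', ...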
dataloaders/pymo/preprocessing.py ADDED
@@ -0,0 +1,726 @@
1
+ '''
2
+ Preprocessing Transformers based on scikit-learn's API
3
+
4
+ By Omid Alemi
5
+ Created on June 12, 2017
6
+ '''
7
+ import copy
8
+ import pandas as pd
9
+ import numpy as np
10
+ from sklearn.base import BaseEstimator, TransformerMixin
11
+ from .Quaternions import Quaternions
12
+ from .rotation_tools import Rotation
13
+
14
+ class MocapParameterizer(BaseEstimator, TransformerMixin):
15
+ def __init__(self, param_type = 'euler'):
16
+ '''
17
+
18
+ param_type = {'euler', 'quat', 'expmap', 'position'}
19
+ '''
20
+ self.param_type = param_type
21
+
22
+ def fit(self, X, y=None):
23
+ return self
24
+
25
+ def transform(self, X, y=None):
26
+ if self.param_type == 'euler':
27
+ return X
28
+ elif self.param_type == 'expmap':
29
+ return self._to_expmap(X)
30
+ elif self.param_type == 'quat':
31
+ return X
32
+ elif self.param_type == 'position':
33
+ return self._to_pos(X)
34
+ else:
35
+ raise UnsupportedParamError('Unsupported param: %s. Valid param types are: euler, quat, expmap, position' % self.param_type)
36
+ # return X
37
+
38
+ def inverse_transform(self, X, copy=None):
39
+ if self.param_type == 'euler':
40
+ return X
41
+ elif self.param_type == 'expmap':
42
+ return self._expmap_to_euler(X)
43
+ elif self.param_type == 'quat':
44
+ raise UnsupportedParamError('quat2euler is not supported')
45
+ elif self.param_type == 'position':
46
+ print('positions 2 eulers is not supported')
47
+ return X
48
+ else:
49
+ raise UnsupportedParamError('Unsupported param: %s. Valid param types are: euler, quat, expmap, position' % self.param_type)
50
+
51
+ def _to_pos(self, X):
52
+ '''Converts joints rotations in Euler angles to joint positions'''
53
+
54
+ Q = []
55
+ for track in X:
56
+ channels = []
57
+ titles = []
58
+ euler_df = track.values
59
+
60
+ # Create a new DataFrame to store the joint positions
61
+ pos_df = pd.DataFrame(index=euler_df.index)
62
+
63
+ # Copy the root rotations into the new DataFrame
64
+ # rxp = '%s_Xrotation'%track.root_name
65
+ # ryp = '%s_Yrotation'%track.root_name
66
+ # rzp = '%s_Zrotation'%track.root_name
67
+ # pos_df[rxp] = pd.Series(data=euler_df[rxp], index=pos_df.index)
68
+ # pos_df[ryp] = pd.Series(data=euler_df[ryp], index=pos_df.index)
69
+ # pos_df[rzp] = pd.Series(data=euler_df[rzp], index=pos_df.index)
70
+
71
+ # List the columns that contain rotation channels
72
+ rot_cols = [c for c in euler_df.columns if ('rotation' in c)]
73
+
74
+ # List the columns that contain position channels
75
+ pos_cols = [c for c in euler_df.columns if ('position' in c)]
76
+
77
+ # List the joints that are not end sites, i.e., have channels
78
+ joints = (joint for joint in track.skeleton)
79
+
80
+ tree_data = {}
81
+
82
+ for joint in track.traverse():
83
+ parent = track.skeleton[joint]['parent']
84
+ rot_order = track.skeleton[joint]['order']
85
+ #print("rot_order:" + joint + " :" + rot_order)
86
+
87
+ # Get the rotation columns that belong to this joint
88
+ rc = euler_df[[c for c in rot_cols if joint in c]]
89
+
90
+ # Get the position columns that belong to this joint
91
+ pc = euler_df[[c for c in pos_cols if joint in c]]
92
+
93
+ # Make sure the columns are organized in xyz order
94
+ if rc.shape[1] < 3:
95
+ euler_values = np.zeros((euler_df.shape[0], 3))
96
+ rot_order = "XYZ"
97
+ else:
98
+ euler_values = np.pi/180.0*np.transpose(np.array([track.values['%s_%srotation'%(joint, rot_order[0])], track.values['%s_%srotation'%(joint, rot_order[1])], track.values['%s_%srotation'%(joint, rot_order[2])]]))
99
+
100
+ if pc.shape[1] < 3:
101
+ pos_values = np.asarray([[0,0,0] for f in pc.iterrows()])
102
+ else:
103
+ pos_values =np.asarray([[f[1]['%s_Xposition'%joint],
104
+ f[1]['%s_Yposition'%joint],
105
+ f[1]['%s_Zposition'%joint]] for f in pc.iterrows()])
106
+
107
+ quats = Quaternions.from_euler(np.asarray(euler_values), order=rot_order.lower(), world=False)
108
+
109
+ tree_data[joint]=[
110
+ [], # to store the rotation matrix
111
+ [] # to store the calculated position
112
+ ]
113
+ if track.root_name == joint:
114
+ tree_data[joint][0] = quats#rotmats
115
+ # tree_data[joint][1] = np.add(pos_values, track.skeleton[joint]['offsets'])
116
+ tree_data[joint][1] = pos_values
117
+ else:
118
+ # for every frame i, multiply this joint's rotmat to the rotmat of its parent
119
+ tree_data[joint][0] = tree_data[parent][0]*quats# np.matmul(rotmats, tree_data[parent][0])
120
+
121
+ # add the position channel to the offset and store it in k, for every frame i
122
+ k = pos_values + np.asarray(track.skeleton[joint]['offsets'])
123
+
124
+ # multiply k to the rotmat of the parent for every frame i
125
+ q = tree_data[parent][0]*k #np.matmul(k.reshape(k.shape[0],1,3), tree_data[parent][0])
126
+
127
+ # add q to the position of the parent, for every frame i
128
+ tree_data[joint][1] = tree_data[parent][1] + q #q.reshape(k.shape[0],3) + tree_data[parent][1]
129
+
130
+ # Create the corresponding columns in the new DataFrame
131
+ pos_df['%s_Xposition'%joint] = pd.Series(data=[e[0] for e in tree_data[joint][1]], index=pos_df.index)
132
+ pos_df['%s_Yposition'%joint] = pd.Series(data=[e[1] for e in tree_data[joint][1]], index=pos_df.index)
133
+ pos_df['%s_Zposition'%joint] = pd.Series(data=[e[2] for e in tree_data[joint][1]], index=pos_df.index)
134
+
135
+
136
+ new_track = track.clone()
137
+ new_track.values = pos_df
138
+ Q.append(new_track)
139
+ return Q
140
+
141
+
142
+ def _to_expmap(self, X):
143
+ '''Converts Euler angles to Exponential Maps'''
144
+
145
+ Q = []
146
+ for track in X:
147
+ channels = []
148
+ titles = []
149
+ euler_df = track.values
150
+
151
+ # Create a new DataFrame to store the exponential map rep
152
+ exp_df = pd.DataFrame(index=euler_df.index)
153
+
154
+ # Copy the root positions into the new DataFrame
155
+ rxp = '%s_Xposition'%track.root_name
156
+ ryp = '%s_Yposition'%track.root_name
157
+ rzp = '%s_Zposition'%track.root_name
158
+ exp_df[rxp] = pd.Series(data=euler_df[rxp], index=exp_df.index)
159
+ exp_df[ryp] = pd.Series(data=euler_df[ryp], index=exp_df.index)
160
+ exp_df[rzp] = pd.Series(data=euler_df[rzp], index=exp_df.index)
161
+
162
+ # List the columns that contain rotation channels
163
+ rots = [c for c in euler_df.columns if ('rotation' in c and 'Nub' not in c)]
164
+
165
+ # List the joints that are not end sites, i.e., have channels
166
+ joints = (joint for joint in track.skeleton if 'Nub' not in joint)
167
+
168
+ for joint in joints:
169
+ r = euler_df[[c for c in rots if joint in c]] # Get the columns that belong to this joint
170
+ euler = [[f[1]['%s_Xrotation'%joint], f[1]['%s_Yrotation'%joint], f[1]['%s_Zrotation'%joint]] for f in r.iterrows()] # Make sure the columns are organized in xyz order
171
+ exps = [Rotation(f, 'euler', from_deg=True).to_expmap() for f in euler] # Convert the eulers to exp maps
172
+
173
+ # Create the corresponding columns in the new DataFrame
174
+
175
+ exp_df['%s_alpha'%joint] = pd.Series(data=[e[0] for e in exps], index=exp_df.index)
176
+ exp_df['%s_beta'%joint] = pd.Series(data=[e[1] for e in exps], index=exp_df.index)
177
+ exp_df['%s_gamma'%joint] = pd.Series(data=[e[2] for e in exps], index=exp_df.index)
178
+
179
+ new_track = track.clone()
180
+ new_track.values = exp_df
181
+ Q.append(new_track)
182
+
183
+ return Q
184
+
185
+ def _expmap_to_euler(self, X):
186
+ Q = []
187
+ for track in X:
188
+ channels = []
189
+ titles = []
190
+ exp_df = track.values
191
+
192
+ # Create a new DataFrame to store the Euler angle rep
193
+ euler_df = pd.DataFrame(index=exp_df.index)
194
+
195
+ # Copy the root positions into the new DataFrame
196
+ rxp = '%s_Xposition'%track.root_name
197
+ ryp = '%s_Yposition'%track.root_name
198
+ rzp = '%s_Zposition'%track.root_name
199
+ euler_df[rxp] = pd.Series(data=exp_df[rxp], index=euler_df.index)
200
+ euler_df[ryp] = pd.Series(data=exp_df[ryp], index=euler_df.index)
201
+ euler_df[rzp] = pd.Series(data=exp_df[rzp], index=euler_df.index)
202
+
203
+ # List the columns that contain rotation channels
204
+ exp_params = [c for c in exp_df.columns if ( any(p in c for p in ['alpha', 'beta','gamma']) and 'Nub' not in c)]
205
+
206
+ # List the joints that are not end sites, i.e., have channels
207
+ joints = (joint for joint in track.skeleton if 'Nub' not in joint)
208
+
209
+ for joint in joints:
210
+ r = exp_df[[c for c in exp_params if joint in c]] # Get the columns that belong to this joint
211
+ expmap = [[f[1]['%s_alpha'%joint], f[1]['%s_beta'%joint], f[1]['%s_gamma'%joint]] for f in r.iterrows()] # Make sure the columns are organized in xyz order
212
+ euler_rots = [Rotation(f, 'expmap').to_euler(True)[0] for f in expmap] # Convert the exp maps to eulers
213
+
214
+ # Create the corresponding columns in the new DataFrame
215
+
216
+ euler_df['%s_Xrotation'%joint] = pd.Series(data=[e[0] for e in euler_rots], index=euler_df.index)
217
+ euler_df['%s_Yrotation'%joint] = pd.Series(data=[e[1] for e in euler_rots], index=euler_df.index)
218
+ euler_df['%s_Zrotation'%joint] = pd.Series(data=[e[2] for e in euler_rots], index=euler_df.index)
219
+
220
+ new_track = track.clone()
221
+ new_track.values = euler_df
222
+ Q.append(new_track)
223
+
224
+ return Q
225
+
226
+
227
+ class JointSelector(BaseEstimator, TransformerMixin):
228
+ '''
229
+ Allows for filtering the mocap data to include only the selected joints
230
+ '''
231
+ def __init__(self, joints, include_root=False):
232
+ self.joints = joints
233
+ self.include_root = include_root
234
+
235
+ def fit(self, X, y=None):
236
+ return self
237
+
238
+ def transform(self, X, y=None):
239
+ selected_joints = []
240
+ selected_channels = []
241
+
242
+ if self.include_root:
243
+ selected_joints.append(X[0].root_name)
244
+
245
+ selected_joints.extend(self.joints)
246
+
247
+ for joint_name in selected_joints:
248
+ selected_channels.extend([o for o in X[0].values.columns if joint_name in o])
249
+
250
+ Q = []
251
+
252
+
253
+ for track in X:
254
+ t2 = track.clone()
255
+
256
+ for key in track.skeleton.keys():
257
+ if key not in selected_joints:
258
+ t2.skeleton.pop(key)
259
+ t2.values = track.values[selected_channels]
260
+
261
+ Q.append(t2)
262
+
263
+
264
+ return Q
265
+
266
+
267
+ class Numpyfier(BaseEstimator, TransformerMixin):
268
+ '''
269
+ Just converts the values in a MocapData object into a numpy array
270
+ Useful for the final stage of a pipeline before training
271
+ '''
272
+ def __init__(self):
273
+ pass
274
+
275
+ def fit(self, X, y=None):
276
+ self.org_mocap_ = X[0].clone()
277
+ self.org_mocap_.values.drop(self.org_mocap_.values.index, inplace=True)
278
+
279
+ return self
280
+
281
+ def transform(self, X, y=None):
282
+ Q = []
283
+
284
+ for track in X:
285
+ Q.append(track.values.values)
286
+
287
+ return np.array(Q)
288
+
289
+ def inverse_transform(self, X, copy=None):
290
+ Q = []
291
+
292
+ for track in X:
293
+
294
+ new_mocap = self.org_mocap_.clone()
295
+ time_index = pd.to_timedelta([f for f in range(track.shape[0])], unit='s')
296
+
297
+ new_df = pd.DataFrame(data=track, index=time_index, columns=self.org_mocap_.values.columns)
298
+
299
+ new_mocap.values = new_df
300
+
301
+
302
+ Q.append(new_mocap)
303
+
304
+ return Q
305
+
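The transformers above follow the scikit-learn estimator API, so they can be chained; a hedged sketch (the joint names and the .bvh path are placeholders):

from sklearn.pipeline import Pipeline
from dataloaders.pymo.parsers import BVHParser
from dataloaders.pymo.preprocessing import MocapParameterizer, JointSelector, Numpyfier

parsed = [BVHParser().parse("path/to/clip.bvh")]         # placeholder path

pipe = Pipeline([
    ('param', MocapParameterizer('position')),           # Euler channels -> joint positions
    ('joints', JointSelector(['Spine', 'Neck', 'Head'], include_root=True)),  # placeholder joints
    ('np', Numpyfier()),                                  # MocapData -> (n_clips, n_frames, n_channels)
])

features = pipe.fit_transform(parsed)
print(features.shape)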
306
+ class RootTransformer(BaseEstimator, TransformerMixin):
307
+ def __init__(self, method):
308
+ """
309
+ Accepted methods:
310
+ abdolute_translation_deltas
311
+ pos_rot_deltas
312
+ """
313
+ self.method = method
314
+
315
+ def fit(self, X, y=None):
316
+ return self
317
+
318
+ def transform(self, X, y=None):
319
+ Q = []
320
+
321
+ for track in X:
322
+ if self.method == 'abdolute_translation_deltas':
323
+ new_df = track.values.copy()
324
+ xpcol = '%s_Xposition'%track.root_name
325
+ ypcol = '%s_Yposition'%track.root_name
326
+ zpcol = '%s_Zposition'%track.root_name
327
+
328
+
329
+ dxpcol = '%s_dXposition'%track.root_name
330
+ dzpcol = '%s_dZposition'%track.root_name
331
+
332
+ dx = track.values[xpcol].diff()
333
+ dz = track.values[zpcol].diff()
334
+
335
+ dx[0] = 0
336
+ dz[0] = 0
337
+
338
+ new_df.drop([xpcol, zpcol], axis=1, inplace=True)
339
+
340
+ new_df[dxpcol] = dx
341
+ new_df[dzpcol] = dz
342
+
343
+ new_track = track.clone()
344
+ new_track.values = new_df
345
+ # end of abdolute_translation_deltas
346
+
347
+ elif self.method == 'pos_rot_deltas':
348
+ new_track = track.clone()
349
+
350
+ # Absolute columns
351
+ xp_col = '%s_Xposition'%track.root_name
352
+ yp_col = '%s_Yposition'%track.root_name
353
+ zp_col = '%s_Zposition'%track.root_name
354
+
355
+ xr_col = '%s_Xrotation'%track.root_name
356
+ yr_col = '%s_Yrotation'%track.root_name
357
+ zr_col = '%s_Zrotation'%track.root_name
358
+
359
+ # Delta columns
360
+ dxp_col = '%s_dXposition'%track.root_name
361
+ dzp_col = '%s_dZposition'%track.root_name
362
+
363
+ dxr_col = '%s_dXrotation'%track.root_name
364
+ dyr_col = '%s_dYrotation'%track.root_name
365
+ dzr_col = '%s_dZrotation'%track.root_name
366
+
367
+
368
+ new_df = track.values.copy()
369
+
370
+ root_pos_x_diff = pd.Series(data=track.values[xp_col].diff(), index=new_df.index)
371
+ root_pos_z_diff = pd.Series(data=track.values[zp_col].diff(), index=new_df.index)
372
+
373
+ root_rot_y_diff = pd.Series(data=track.values[yr_col].diff(), index=new_df.index)
374
+ root_rot_x_diff = pd.Series(data=track.values[xr_col].diff(), index=new_df.index)
375
+ root_rot_z_diff = pd.Series(data=track.values[zr_col].diff(), index=new_df.index)
376
+
377
+
378
+ root_pos_x_diff[0] = 0
379
+ root_pos_z_diff[0] = 0
380
+
381
+ root_rot_y_diff[0] = 0
382
+ root_rot_x_diff[0] = 0
383
+ root_rot_z_diff[0] = 0
384
+
385
+ new_df.drop([xr_col, yr_col, zr_col, xp_col, zp_col], axis=1, inplace=True)
386
+
387
+ new_df[dxp_col] = root_pos_x_diff
388
+ new_df[dzp_col] = root_pos_z_diff
389
+
390
+ new_df[dxr_col] = root_rot_x_diff
391
+ new_df[dyr_col] = root_rot_y_diff
392
+ new_df[dzr_col] = root_rot_z_diff
393
+
394
+ new_track.values = new_df
395
+
396
+ Q.append(new_track)
397
+
398
+ return Q
399
+
400
+ def inverse_transform(self, X, copy=None, start_pos=None):
401
+ Q = []
402
+
403
+ #TODO: simplify this implementation
404
+
405
+ startx = 0
406
+ startz = 0
407
+
408
+ if start_pos is not None:
409
+ startx, startz = start_pos
410
+
411
+ for track in X:
412
+ new_track = track.clone()
413
+ if self.method == 'abdolute_translation_deltas':
414
+ new_df = new_track.values
415
+ xpcol = '%s_Xposition'%track.root_name
416
+ ypcol = '%s_Yposition'%track.root_name
417
+ zpcol = '%s_Zposition'%track.root_name
418
+
419
+
420
+ dxpcol = '%s_dXposition'%track.root_name
421
+ dzpcol = '%s_dZposition'%track.root_name
422
+
423
+ dx = track.values[dxpcol].values
424
+ dz = track.values[dzpcol].values
425
+
426
+ recx = [startx]
427
+ recz = [startz]
428
+
429
+ for i in range(dx.shape[0]-1):
430
+ recx.append(recx[i]+dx[i+1])
431
+ recz.append(recz[i]+dz[i+1])
432
+
433
+ # recx = [recx[i]+dx[i+1] for i in range(dx.shape[0]-1)]
434
+ # recz = [recz[i]+dz[i+1] for i in range(dz.shape[0]-1)]
435
+ # recx = dx[:-1] + dx[1:]
436
+ # recz = dz[:-1] + dz[1:]
437
+
438
+ new_df[xpcol] = pd.Series(data=recx, index=new_df.index)
439
+ new_df[zpcol] = pd.Series(data=recz, index=new_df.index)
440
+
441
+ new_df.drop([dxpcol, dzpcol], axis=1, inplace=True)
442
+
443
+ new_track.values = new_df
444
+ # end of abdolute_translation_deltas
445
+
446
+ elif self.method == 'pos_rot_deltas':
447
+ new_track = track.clone()
448
+
449
+ # Absolute columns
450
+ xp_col = '%s_Xposition'%track.root_name
451
+ yp_col = '%s_Yposition'%track.root_name
452
+ zp_col = '%s_Zposition'%track.root_name
453
+
454
+ xr_col = '%s_Xrotation'%track.root_name
455
+ yr_col = '%s_Yrotation'%track.root_name
456
+ zr_col = '%s_Zrotation'%track.root_name
457
+
458
+ # Delta columns
459
+ dxp_col = '%s_dXposition'%track.root_name
460
+ dzp_col = '%s_dZposition'%track.root_name
461
+
462
+ dxr_col = '%s_dXrotation'%track.root_name
463
+ dyr_col = '%s_dYrotation'%track.root_name
464
+ dzr_col = '%s_dZrotation'%track.root_name
465
+
466
+
467
+ new_df = track.values.copy()
468
+
469
+ dx = track.values[dxp_col].values
470
+ dz = track.values[dzp_col].values
471
+
472
+ drx = track.values[dxr_col].values
473
+ dry = track.values[dyr_col].values
474
+ drz = track.values[dzr_col].values
475
+
476
+ rec_xp = [startx]
477
+ rec_zp = [startz]
478
+
479
+ rec_xr = [0]
480
+ rec_yr = [0]
481
+ rec_zr = [0]
482
+
483
+
484
+ for i in range(dx.shape[0]-1):
485
+ rec_xp.append(rec_xp[i]+dx[i+1])
486
+ rec_zp.append(rec_zp[i]+dz[i+1])
487
+
488
+ rec_xr.append(rec_xr[i]+drx[i+1])
489
+ rec_yr.append(rec_yr[i]+dry[i+1])
490
+ rec_zr.append(rec_zr[i]+drz[i+1])
491
+
492
+
493
+ new_df[xp_col] = pd.Series(data=rec_xp, index=new_df.index)
494
+ new_df[zp_col] = pd.Series(data=rec_zp, index=new_df.index)
495
+
496
+ new_df[xr_col] = pd.Series(data=rec_xr, index=new_df.index)
497
+ new_df[yr_col] = pd.Series(data=rec_yr, index=new_df.index)
498
+ new_df[zr_col] = pd.Series(data=rec_zr, index=new_df.index)
499
+
500
+ new_df.drop([dxr_col, dyr_col, dzr_col, dxp_col, dzp_col], axis=1, inplace=True)
501
+
502
+
503
+ new_track.values = new_df
504
+
505
+ Q.append(new_track)
506
+
507
+ return Q
508
+
509
+
510
+ class RootCentricPositionNormalizer(BaseEstimator, TransformerMixin):
511
+ def __init__(self):
512
+ pass
513
+
514
+ def fit(self, X, y=None):
515
+ return self
516
+
517
+ def transform(self, X, y=None):
518
+ Q = []
519
+
520
+ for track in X:
521
+ new_track = track.clone()
522
+
523
+ rxp = '%s_Xposition'%track.root_name
524
+ ryp = '%s_Yposition'%track.root_name
525
+ rzp = '%s_Zposition'%track.root_name
526
+
527
+ projected_root_pos = track.values[[rxp, ryp, rzp]]
528
+
529
+ projected_root_pos.loc[:,ryp] = 0 # we want the root's projection on the floor plane as the ref
530
+
531
+ new_df = pd.DataFrame(index=track.values.index)
532
+
533
+ all_but_root = [joint for joint in track.skeleton if track.root_name not in joint]
534
+ # all_but_root = [joint for joint in track.skeleton]
535
+ for joint in all_but_root:
536
+ new_df['%s_Xposition'%joint] = pd.Series(data=track.values['%s_Xposition'%joint]-projected_root_pos[rxp], index=new_df.index)
537
+ new_df['%s_Yposition'%joint] = pd.Series(data=track.values['%s_Yposition'%joint]-projected_root_pos[ryp], index=new_df.index)
538
+ new_df['%s_Zposition'%joint] = pd.Series(data=track.values['%s_Zposition'%joint]-projected_root_pos[rzp], index=new_df.index)
539
+
540
+
541
+ # keep the root as it is now
542
+ new_df[rxp] = track.values[rxp]
543
+ new_df[ryp] = track.values[ryp]
544
+ new_df[rzp] = track.values[rzp]
545
+
546
+ new_track.values = new_df
547
+
548
+ Q.append(new_track)
549
+
550
+ return Q
551
+
552
+ def inverse_transform(self, X, copy=None):
553
+ Q = []
554
+
555
+ for track in X:
556
+ new_track = track.clone()
557
+
558
+ rxp = '%s_Xposition'%track.root_name
559
+ ryp = '%s_Yposition'%track.root_name
560
+ rzp = '%s_Zposition'%track.root_name
561
+
562
+ projected_root_pos = track.values[[rxp, ryp, rzp]]
563
+
564
+ projected_root_pos.loc[:,ryp] = 0 # we want the root's projection on the floor plane as the ref
565
+
566
+ new_df = pd.DataFrame(index=track.values.index)
567
+
568
+ for joint in track.skeleton:
569
+ new_df['%s_Xposition'%joint] = pd.Series(data=track.values['%s_Xposition'%joint]+projected_root_pos[rxp], index=new_df.index)
570
+ new_df['%s_Yposition'%joint] = pd.Series(data=track.values['%s_Yposition'%joint]+projected_root_pos[ryp], index=new_df.index)
571
+ new_df['%s_Zposition'%joint] = pd.Series(data=track.values['%s_Zposition'%joint]+projected_root_pos[rzp], index=new_df.index)
572
+
573
+
574
+ new_track.values = new_df
575
+
576
+ Q.append(new_track)
577
+
578
+ return Q
579
+
580
+
581
+ class Flattener(BaseEstimator, TransformerMixin):
582
+ def __init__(self):
583
+ pass
584
+
585
+ def fit(self, X, y=None):
586
+ return self
587
+
588
+ def transform(self, X, y=None):
589
+ return np.concatenate(X, axis=0)
590
+
591
+ class ConstantsRemover(BaseEstimator, TransformerMixin):
592
+ '''
593
+ For now it just looks at the first track
594
+ '''
595
+
596
+ def __init__(self, eps = 10e-10):
597
+ self.eps = eps
598
+
599
+
600
+ def fit(self, X, y=None):
601
+ stds = X[0].values.std()
602
+ cols = X[0].values.columns.values
603
+ self.const_dims_ = [c for c in cols if (stds[c] < self.eps).any()]
604
+ self.const_values_ = {c:X[0].values[c].values[0] for c in cols if (stds[c] < self.eps).any()}
605
+ return self
606
+
607
+ def transform(self, X, y=None):
608
+ Q = []
609
+
610
+
611
+ for track in X:
612
+ t2 = track.clone()
613
+ #for key in t2.skeleton.keys():
614
+ # if key in self.ConstDims_:
615
+ # t2.skeleton.pop(key)
616
+ t2.values = track.values[track.values.columns.difference(self.const_dims_)]
617
+ Q.append(t2)
618
+
619
+ return Q
620
+
621
+ def inverse_transform(self, X, copy=None):
622
+ Q = []
623
+
624
+ for track in X:
625
+ t2 = track.clone()
626
+ for d in self.const_dims_:
627
+ t2.values[d] = self.const_values_[d]
628
+ Q.append(t2)
629
+
630
+ return Q
631
+
632
+ class ListStandardScaler(BaseEstimator, TransformerMixin):
633
+ def __init__(self, is_DataFrame=False):
634
+ self.is_DataFrame = is_DataFrame
635
+
636
+ def fit(self, X, y=None):
637
+ if self.is_DataFrame:
638
+ X_train_flat = np.concatenate([m.values for m in X], axis=0)
639
+ else:
640
+ X_train_flat = np.concatenate([m for m in X], axis=0)
641
+
642
+ self.data_mean_ = np.mean(X_train_flat, axis=0)
643
+ self.data_std_ = np.std(X_train_flat, axis=0)
644
+
645
+ return self
646
+
647
+ def transform(self, X, y=None):
648
+ Q = []
649
+
650
+ for track in X:
651
+ if self.is_DataFrame:
652
+ normalized_track = track.copy()
653
+ normalized_track.values = (track.values - self.data_mean_) / self.data_std_
654
+ else:
655
+ normalized_track = (track - self.data_mean_) / self.data_std_
656
+
657
+ Q.append(normalized_track)
658
+
659
+ if self.is_DataFrame:
660
+ return Q
661
+ else:
662
+ return np.array(Q)
663
+
664
+ def inverse_transform(self, X, copy=None):
665
+ Q = []
666
+
667
+ for track in X:
668
+
669
+ if self.is_DataFrame:
670
+ unnormalized_track = track.copy()
671
+ unnormalized_track.values = (track.values * self.data_std_) + self.data_mean_
672
+ else:
673
+ unnormalized_track = (track * self.data_std_) + self.data_mean_
674
+
675
+ Q.append(unnormalized_track)
676
+
677
+ if self.is_DataFrame:
678
+ return Q
679
+ else:
680
+ return np.array(Q)
681
+
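Usage note (not part of the diff): ListStandardScaler is the one transformer above that also accepts plain numpy tracks (is_DataFrame=False), so a minimal self-contained round-trip check looks like the sketch below; the random, equal-length tracks are purely illustrative.

    import numpy as np
    tracks = [np.random.randn(100, 6) * 3.0 + 1.0 for _ in range(2)]  # two dummy tracks of equal length
    scaler = ListStandardScaler(is_DataFrame=False)
    normed = scaler.fit(tracks).transform(tracks)            # stacked array of z-scored tracks
    restored = scaler.inverse_transform(normed)
    assert np.allclose(restored[0], tracks[0], atol=1e-6)    # round trip recovers the original values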
682
+ class DownSampler(BaseEstimator, TransformerMixin):
683
+ def __init__(self, rate):
684
+ self.rate = rate
685
+
686
+
687
+ def fit(self, X, y=None):
688
+
689
+ return self
690
+
691
+ def transform(self, X, y=None):
692
+ Q = []
693
+
694
+ for track in X:
695
+ #print(track.values.size)
696
+ #new_track = track.clone()
697
+ #new_track.values = track.values[0:-1:self.rate]
698
+ #print(new_track.values.size)
699
+ new_track = track[0:-1:self.rate]
700
+ Q.append(new_track)
701
+
702
+ return Q
703
+
704
+ def inverse_transform(self, X, copy=None):
705
+ return X
706
+
707
+
708
+ #TODO: JointsSelector (x)
709
+ #TODO: SegmentMaker
710
+ #TODO: DynamicFeaturesAdder
711
+ #TODO: ShapeFeaturesAdder
712
+ #TODO: DataFrameNumpier (x)
713
+
714
+ class TemplateTransform(BaseEstimator, TransformerMixin):
715
+ def __init__(self):
716
+ pass
717
+
718
+ def fit(self, X, y=None):
719
+ return self
720
+
721
+ def transform(self, X, y=None):
722
+ return X
723
+
724
+ class UnsupportedParamError(Exception):
725
+ def __init__(self, message):
726
+ self.message = message
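Usage note (not part of the diff): these transformers subclass BaseEstimator/TransformerMixin, so the natural way to chain them is an sklearn Pipeline over parsed MocapData tracks. The sketch below is hedged: BVHParser, its module path, the file name, and the joint names are assumptions for illustration, not part of this file.

    from sklearn.pipeline import Pipeline
    from dataloaders.pymo.parsers import BVHParser    # parser assumed to live in parsers.py of this upload

    track = BVHParser().parse('example.bvh')          # placeholder BVH file
    pipe = Pipeline([
        ('select', JointSelector(['Spine', 'Head'], include_root=True)),  # hypothetical joint names
        ('root', RootTransformer('pos_rot_deltas')),
        ('const', ConstantsRemover()),
        ('np', Numpyfier()),
    ])
    X = pipe.fit_transform([track])                   # numpy array shaped (tracks, frames, features)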
dataloaders/pymo/rotation_tools.py ADDED
@@ -0,0 +1,153 @@
1
+ '''
2
+ Tools for Manipulating and Converting 3D Rotations
3
+
4
+ By Omid Alemi
5
+ Created: June 12, 2017
6
+
7
+ Adapted from that matlab file...
8
+ '''
9
+
10
+ import math
11
+ import numpy as np
12
+
13
+ def deg2rad(x):
14
+ return x/180*math.pi
15
+
16
+
17
+ def rad2deg(x):
18
+ return x/math.pi*180
19
+
20
+ class Rotation():
21
+ def __init__(self,rot, param_type, rotation_order, **params):
22
+ self.rotmat = []
23
+ self.rotation_order = rotation_order
24
+ if param_type == 'euler':
25
+ self._from_euler(rot[0],rot[1],rot[2], params)
26
+ elif param_type == 'expmap':
27
+ self._from_expmap(rot[0], rot[1], rot[2], params)
28
+
29
+ def _from_euler(self, alpha, beta, gamma, params):
30
+ '''Expecting degrees'''
31
+
32
+ if params['from_deg']==True:
33
+ alpha = deg2rad(alpha)
34
+ beta = deg2rad(beta)
35
+ gamma = deg2rad(gamma)
36
+
37
+ ca = math.cos(alpha)
38
+ cb = math.cos(beta)
39
+ cg = math.cos(gamma)
40
+ sa = math.sin(alpha)
41
+ sb = math.sin(beta)
42
+ sg = math.sin(gamma)
43
+
44
+ Rx = np.asarray([[1, 0, 0],
45
+ [0, ca, sa],
46
+ [0, -sa, ca]
47
+ ])
48
+
49
+ Ry = np.asarray([[cb, 0, -sb],
50
+ [0, 1, 0],
51
+ [sb, 0, cb]])
52
+
53
+ Rz = np.asarray([[cg, sg, 0],
54
+ [-sg, cg, 0],
55
+ [0, 0, 1]])
56
+
57
+ self.rotmat = np.eye(3)
58
+
59
+ ############################ inner product rotation matrix in order defined at BVH file #########################
60
+ for axis in self.rotation_order :
61
+ if axis == 'X' :
62
+ self.rotmat = np.matmul(Rx, self.rotmat)
63
+ elif axis == 'Y':
64
+ self.rotmat = np.matmul(Ry, self.rotmat)
65
+ else :
66
+ self.rotmat = np.matmul(Rz, self.rotmat)
67
+ ################################################################################################################
68
+
69
+ def _from_expmap(self, alpha, beta, gamma, params):
70
+ if (alpha == 0 and beta == 0 and gamma == 0):
71
+ self.rotmat = np.eye(3)
72
+ return
73
+
74
+ #TODO: Check exp map params
75
+
76
+ theta = np.linalg.norm([alpha, beta, gamma])
77
+
78
+ expmap = [alpha, beta, gamma] / theta
79
+
80
+ x = expmap[0]
81
+ y = expmap[1]
82
+ z = expmap[2]
83
+
84
+ s = math.sin(theta/2)
85
+ c = math.cos(theta/2)
86
+
87
+ self.rotmat = np.asarray([
88
+ [2*(x**2-1)*s**2+1, 2*x*y*s**2-2*z*c*s, 2*x*z*s**2+2*y*c*s],
89
+ [2*x*y*s**2+2*z*c*s, 2*(y**2-1)*s**2+1, 2*y*z*s**2-2*x*c*s],
90
+ [2*x*z*s**2-2*y*c*s, 2*y*z*s**2+2*x*c*s , 2*(z**2-1)*s**2+1]
91
+ ])
92
+
93
+
94
+
95
+ def get_euler_axis(self):
96
+ R = self.rotmat
97
+ theta = math.acos((self.rotmat.trace() - 1) / 2)
98
+ axis = np.asarray([R[2,1] - R[1,2], R[0,2] - R[2,0], R[1,0] - R[0,1]])
99
+ axis = axis/(2*math.sin(theta))
100
+ return theta, axis
101
+
102
+ def to_expmap(self):
103
+ theta, axis = self.get_euler_axis()
104
+ rot_arr = theta * axis
105
+ if np.isnan(rot_arr).any():
106
+ rot_arr = [0, 0, 0]
107
+ return rot_arr
108
+
109
+ def to_euler(self, use_deg=False):
110
+ eulers = np.zeros((2, 3))
111
+
112
+ if np.absolute(np.absolute(self.rotmat[2, 0]) - 1) < 1e-12:
113
+ #GIMBAL LOCK!
114
+ print('Gimbal')
115
+ if np.absolute(self.rotmat[2, 0] - 1) < 1e-12:
116
+ eulers[:,0] = math.atan2(-self.rotmat[0,1], -self.rotmat[0,2])
117
+ eulers[:,1] = -math.pi/2
118
+ else:
119
+ eulers[:,0] = math.atan2(self.rotmat[0,1], -self.rotmat[0,2])
120
+ eulers[:,1] = math.pi/2
121
+
122
+ return eulers
123
+
124
+ theta = - math.asin(self.rotmat[2,0])
125
+ theta2 = math.pi - theta
126
+
127
+ # psi1, psi2
128
+ eulers[0,0] = math.atan2(self.rotmat[2,1]/math.cos(theta), self.rotmat[2,2]/math.cos(theta))
129
+ eulers[1,0] = math.atan2(self.rotmat[2,1]/math.cos(theta2), self.rotmat[2,2]/math.cos(theta2))
130
+
131
+ # theta1, theta2
132
+ eulers[0,1] = theta
133
+ eulers[1,1] = theta2
134
+
135
+ # phi1, phi2
136
+ eulers[0,2] = math.atan2(self.rotmat[1,0]/math.cos(theta), self.rotmat[0,0]/math.cos(theta))
137
+ eulers[1,2] = math.atan2(self.rotmat[1,0]/math.cos(theta2), self.rotmat[0,0]/math.cos(theta2))
138
+
139
+ if use_deg:
140
+ eulers = rad2deg(eulers)
141
+
142
+ return eulers
143
+
144
+ def to_quat(self):
145
+ #TODO
146
+ pass
147
+
148
+ def __str__(self):
149
+ return "Rotation Matrix: \n " + self.rotmat.__str__()
150
+
151
+
152
+
153
+
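Usage sketch (illustrative only): build a Rotation from Euler angles in a BVH-style rotation order, then read back the exponential-map and Euler parameterizations.

    rot = Rotation([90.0, 0.0, 0.0], 'euler', 'XYZ', from_deg=True)  # 90 degrees about X
    print(rot.rotmat)          # 3x3 matrix in this module's convention
    print(rot.to_expmap())     # axis-angle vector, norm ~ pi/2
    print(rot.to_euler(True))  # the two Euler solutions, in degrees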
dataloaders/pymo/rotation_tools.py! ADDED
@@ -0,0 +1,69 @@
1
+ '''
2
+ Tools for Manipulating and Converting 3D Rotations
3
+
4
+ By Omid Alemi
5
+ Created: June 12, 2017
6
+
7
+ Adapted from that matlab file...
8
+ '''
9
+
10
+ import math
11
+ import numpy as np
12
+
13
+ def deg2rad(x):
14
+ return x/180*math.pi
15
+
16
+ class Rotation():
17
+ def __init__(self,rot, param_type, **params):
18
+ self.rotmat = []
19
+ if param_type == 'euler':
20
+ self._from_euler(rot[0],rot[1],rot[2], params)
21
+
22
+ def _from_euler(self, alpha, beta, gamma, params):
23
+ '''Expecting degrees'''
24
+
25
+ if params['from_deg']==True:
26
+ alpha = deg2rad(alpha)
27
+ beta = deg2rad(beta)
28
+ gamma = deg2rad(gamma)
29
+
30
+ Rx = np.asarray([[1, 0, 0],
31
+ [0, math.cos(alpha), -math.sin(alpha)],
32
+ [0, math.sin(alpha), math.cos(alpha)]
33
+ ])
34
+
35
+ Ry = np.asarray([[math.cos(beta), 0, math.sin(beta)],
36
+ [0, 1, 0],
37
+ [-math.sin(beta), 0, math.cos(beta)]])
38
+
39
+ Rz = np.asarray([[math.cos(gamma), -math.sin(gamma), 0],
40
+ [math.sin(gamma), math.cos(gamma), 0],
41
+ [0, 0, 1]])
42
+
43
+ self.rotmat = np.matmul(np.matmul(Rz, Ry), Rx).T
44
+
45
+ def get_euler_axis(self):
46
+ R = self.rotmat
47
+ theta = math.acos((self.rotmat.trace() - 1) / 2)
48
+ axis = np.asarray([R[2,1] - R[1,2], R[0,2] - R[2,0], R[1,0] - R[0,1]])
49
+ axis = axis/(2*math.sin(theta))
50
+ return theta, axis
51
+
52
+ def to_expmap(self):
53
+ theta, axis = self.get_euler_axis()
54
+ rot_arr = theta * axis
55
+ if np.isnan(rot_arr).any():
56
+ rot_arr = [0, 0, 0]
57
+ return rot_arr
58
+
59
+ def to_euler(self):
60
+ #TODO
61
+ pass
62
+
63
+ def to_quat(self):
64
+ #TODO
65
+ pass
66
+
67
+
68
+
69
+
dataloaders/pymo/viz_tools.py ADDED
@@ -0,0 +1,236 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import IPython
5
+ import os
+ from .writers import BVHWriter # BVHWriter is used by nb_play_mocap_fromurl below but was not imported
6
+
7
+ def save_fig(fig_id, tight_layout=True):
8
+ if tight_layout:
9
+ plt.tight_layout()
10
+ plt.savefig(fig_id + '.png', format='png', dpi=300)
11
+
12
+
13
+ def draw_stickfigure(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)):
14
+ if ax is None:
15
+ fig = plt.figure(figsize=figsize)
16
+ ax = fig.add_subplot(111)
17
+
18
+ if joints is None:
19
+ joints_to_draw = mocap_track.skeleton.keys()
20
+ else:
21
+ joints_to_draw = joints
22
+
23
+ if data is None:
24
+ df = mocap_track.values
25
+ else:
26
+ df = data
27
+
28
+ for joint in joints_to_draw:
29
+ ax.scatter(x=df['%s_Xposition'%joint][frame],
30
+ y=df['%s_Yposition'%joint][frame],
31
+ alpha=0.6, c='b', marker='o')
32
+
33
+ parent_x = df['%s_Xposition'%joint][frame]
34
+ parent_y = df['%s_Yposition'%joint][frame]
35
+
36
+ children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw]
37
+
38
+ for c in children_to_draw:
39
+ child_x = df['%s_Xposition'%c][frame]
40
+ child_y = df['%s_Yposition'%c][frame]
41
+ ax.plot([parent_x, child_x], [parent_y, child_y], 'k-', lw=2)
42
+
43
+ if draw_names:
44
+ ax.annotate(joint,
45
+ (df['%s_Xposition'%joint][frame] + 0.1,
46
+ df['%s_Yposition'%joint][frame] + 0.1))
47
+
48
+ return ax
49
+
50
+ def draw_stickfigure3d(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)):
51
+ from mpl_toolkits.mplot3d import Axes3D
52
+
53
+ if ax is None:
54
+ fig = plt.figure(figsize=figsize)
55
+ ax = fig.add_subplot(111, projection='3d')
56
+
57
+ if joints is None:
58
+ joints_to_draw = mocap_track.skeleton.keys()
59
+ else:
60
+ joints_to_draw = joints
61
+
62
+ if data is None:
63
+ df = mocap_track.values
64
+ else:
65
+ df = data
66
+
67
+ for joint in joints_to_draw:
68
+ parent_x = df['%s_Xposition'%joint][frame]
69
+ parent_y = df['%s_Zposition'%joint][frame]
70
+ parent_z = df['%s_Yposition'%joint][frame]
71
+ # ^ In mocaps, Y is the up-right axis
72
+
73
+ ax.scatter(xs=parent_x,
74
+ ys=parent_y,
75
+ zs=parent_z,
76
+ alpha=0.6, c='b', marker='o')
77
+
78
+
79
+ children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw]
80
+
81
+ for c in children_to_draw:
82
+ child_x = df['%s_Xposition'%c][frame]
83
+ child_y = df['%s_Zposition'%c][frame]
84
+ child_z = df['%s_Yposition'%c][frame]
85
+ # ^ In mocaps, Y is the up-right axis
86
+
87
+ ax.plot([parent_x, child_x], [parent_y, child_y], [parent_z, child_z], 'k-', lw=2, c='black')
88
+
89
+ if draw_names:
90
+ ax.text(x=parent_x + 0.1,
91
+ y=parent_y + 0.1,
92
+ z=parent_z + 0.1,
93
+ s=joint,
94
+ color=(0, 0, 0, 0.9)) # matplotlib expects an RGBA tuple, not a CSS 'rgba()' string
95
+
96
+ return ax
97
+
98
+
99
+ def sketch_move(mocap_track, data=None, ax=None, figsize=(16,8)):
100
+ if ax is None:
101
+ fig = plt.figure(figsize=figsize)
102
+ ax = fig.add_subplot(111)
103
+
104
+ if data is None:
105
+ data = mocap_track.values
106
+
107
+ for frame in range(0, data.shape[0], 4):
108
+ # draw_stickfigure(mocap_track, f, data=data, ax=ax)
109
+
110
+ for joint in mocap_track.skeleton.keys():
111
+ children_to_draw = [c for c in mocap_track.skeleton[joint]['children']]
112
+
113
+ parent_x = data['%s_Xposition'%joint][frame]
114
+ parent_y = data['%s_Yposition'%joint][frame]
115
+
116
+ frame_alpha = frame/data.shape[0]
117
+
118
+ for c in children_to_draw:
119
+ child_x = data['%s_Xposition'%c][frame]
120
+ child_y = data['%s_Yposition'%c][frame]
121
+
122
+ ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha)
123
+
124
+
125
+
126
+ def viz_cnn_filter(feature_to_viz, mocap_track, data, gap=25):
127
+ fig = plt.figure(figsize=(16,4))
128
+ ax = plt.subplot2grid((1,8),(0,0))
129
+ ax.imshow(feature_to_viz.T, aspect='auto', interpolation='nearest')
130
+
131
+ ax = plt.subplot2grid((1,8),(0,1), colspan=7)
132
+ for frame in range(feature_to_viz.shape[0]):
133
+ frame_alpha = 0.2#frame/data.shape[0] * 2 + 0.2
134
+
135
+ for joint_i, joint in enumerate(mocap_track.skeleton.keys()):
136
+ children_to_draw = [c for c in mocap_track.skeleton[joint]['children']]
137
+
138
+ parent_x = data['%s_Xposition'%joint][frame] + frame * gap
139
+ parent_y = data['%s_Yposition'%joint][frame]
140
+
141
+ ax.scatter(x=parent_x,
142
+ y=parent_y,
143
+ alpha=0.6,
144
+ cmap='RdBu',
145
+ c=feature_to_viz[frame][joint_i] * 10000,
146
+ marker='o',
147
+ s = abs(feature_to_viz[frame][joint_i] * 10000))
148
+ plt.axis('off')
149
+ for c in children_to_draw:
150
+ child_x = data['%s_Xposition'%c][frame] + frame * gap
151
+ child_y = data['%s_Yposition'%c][frame]
152
+
153
+ ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha)
154
+
155
+
156
+ def print_skel(X):
157
+ stack = [X.root_name]
158
+ tab=0
159
+ while stack:
160
+ joint = stack.pop()
161
+ tab = len(stack)
162
+ print('%s- %s (%s)'%('| '*tab, joint, X.skeleton[joint]['parent']))
163
+ for c in X.skeleton[joint]['children']:
164
+ stack.append(c)
165
+
166
+
167
+ def nb_play_mocap_fromurl(mocap, mf, frame_time=1/30, scale=1, base_url='http://titan:8385'):
168
+ if mf == 'bvh':
169
+ bw = BVHWriter()
170
+ with open('test.bvh', 'w') as ofile:
171
+ bw.write(mocap, ofile)
172
+
173
+ filepath = '../notebooks/test.bvh'
174
+ elif mf == 'pos':
175
+ c = list(mocap.values.columns)
176
+
177
+ for cc in c:
178
+ if 'rotation' in cc:
179
+ c.remove(cc)
180
+ mocap.values.to_csv('test.csv', index=False, columns=c)
181
+
182
+ filepath = '../notebooks/test.csv'
183
+ else:
184
+ return
185
+
186
+ url = '%s/mocapplayer/player.html?data_url=%s&scale=%f&cz=200&order=xzyi&frame_time=%f'%(base_url, filepath, scale, frame_time)
187
+ iframe = '<iframe src=' + url + ' width="100%" height=500></iframe>'
188
+ link = '<a href=%s target="_blank">New Window</a>'%url
189
+ return IPython.display.HTML(iframe+link)
190
+
191
+ def nb_play_mocap(mocap, mf, meta=None, frame_time=1/30, scale=1, camera_z=500, base_url=None):
192
+ data_template = 'var dataBuffer = `$$DATA$$`;'
193
+ data_template += 'var metadata = $$META$$;'
194
+ data_template += 'start(dataBuffer, metadata, $$CZ$$, $$SCALE$$, $$FRAMETIME$$);'
195
+ dir_path = os.path.dirname(os.path.realpath(__file__))
196
+
197
+
198
+ if base_url is None:
199
+ base_url = os.path.join(dir_path, 'mocapplayer/playBuffer.html')
200
+
201
+ # print(dir_path)
202
+
203
+ if mf == 'bvh':
204
+ pass
205
+ elif mf == 'pos':
206
+ cols = list(mocap.values.columns)
207
+ for c in cols:
208
+ if 'rotation' in c:
209
+ cols.remove(c)
210
+
211
+ data_csv = mocap.values.to_csv(index=False, columns=cols)
212
+
213
+ if meta is not None:
214
+ lines = [','.join(item) for item in meta.astype('str')]
215
+ meta_csv = '[' + ','.join('[%s]'%l for l in lines) +']'
216
+ else:
217
+ meta_csv = '[]'
218
+
219
+ data_assigned = data_template.replace('$$DATA$$', data_csv)
220
+ data_assigned = data_assigned.replace('$$META$$', meta_csv)
221
+ data_assigned = data_assigned.replace('$$CZ$$', str(camera_z))
222
+ data_assigned = data_assigned.replace('$$SCALE$$', str(scale))
223
+ data_assigned = data_assigned.replace('$$FRAMETIME$$', str(frame_time))
224
+
225
+ else:
226
+ return
227
+
228
+
229
+
230
+ with open(os.path.join(dir_path, 'mocapplayer/data.js'), 'w') as oFile:
231
+ oFile.write(data_assigned)
232
+
233
+ url = '%s?&cz=200&order=xzyi&frame_time=%f&scale=%f'%(base_url, frame_time, scale)
234
+ iframe = '<iframe frameborder="0" src=' + url + ' width="100%" height=500></iframe>'
235
+ link = '<a href=%s target="_blank">New Window</a>'%url
236
+ return IPython.display.HTML(iframe+link)
dataloaders/pymo/writers.py ADDED
@@ -0,0 +1,55 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ class BVHWriter():
5
+ def __init__(self):
6
+ pass
7
+
8
+ def write(self, X, ofile):
9
+
10
+ # Writing the skeleton info
11
+ ofile.write('HIERARCHY\n')
12
+
13
+ self.motions_ = []
14
+ self._printJoint(X, X.root_name, 0, ofile)
15
+
16
+ # Writing the motion header
17
+ ofile.write('MOTION\n')
18
+ ofile.write('Frames: %d\n'%X.values.shape[0])
19
+ ofile.write('Frame Time: %f\n'%X.framerate)
20
+
21
+ # Writing the data
22
+ self.motions_ = np.asarray(self.motions_).T
23
+ lines = [" ".join(item) for item in self.motions_.astype(str)]
24
+ ofile.write("".join("%s\n"%l for l in lines))
25
+
26
+ def _printJoint(self, X, joint, tab, ofile):
27
+
28
+ if X.skeleton[joint]['parent'] == None:
29
+ ofile.write('ROOT %s\n'%joint)
30
+ elif len(X.skeleton[joint]['children']) > 0:
31
+ ofile.write('%sJOINT %s\n'%('\t'*(tab), joint))
32
+ else:
33
+ ofile.write('%sEnd site\n'%('\t'*(tab)))
34
+
35
+ ofile.write('%s{\n'%('\t'*(tab)))
36
+
37
+ ofile.write('%sOFFSET %3.5f %3.5f %3.5f\n'%('\t'*(tab+1),
38
+ X.skeleton[joint]['offsets'][0],
39
+ X.skeleton[joint]['offsets'][1],
40
+ X.skeleton[joint]['offsets'][2]))
41
+ channels = X.skeleton[joint]['channels']
42
+ n_channels = len(channels)
43
+
44
+ if n_channels > 0:
45
+ for ch in channels:
46
+ self.motions_.append(np.asarray(X.values['%s_%s'%(joint, ch)].values))
47
+
48
+ if len(X.skeleton[joint]['children']) > 0:
49
+ ch_str = ''.join(' %s'*n_channels%tuple(channels))
50
+ ofile.write('%sCHANNELS %d%s\n' %('\t'*(tab+1), n_channels, ch_str))
51
+
52
+ for c in X.skeleton[joint]['children']:
53
+ self._printJoint(X, c, tab+1, ofile)
54
+
55
+ ofile.write('%s}\n'%('\t'*(tab)))
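Round-trip sketch (hedged: the parser class name, module paths, and file names are assumptions; the parser itself lives elsewhere in this upload):

    from dataloaders.pymo.parsers import BVHParser
    from dataloaders.pymo.writers import BVHWriter

    parsed = BVHParser().parse('input.bvh')       # placeholder input file
    with open('roundtrip.bvh', 'w') as f:
        BVHWriter().write(parsed, f)              # re-serialize the skeleton and motion frames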
dataloaders/utils/__pycache__/audio_features.cpython-312.pyc ADDED
Binary file (4.64 kB). View file
 
dataloaders/utils/__pycache__/other_tools.cpython-312.pyc ADDED
Binary file (37.2 kB). View file
 
dataloaders/utils/__pycache__/rotation_conversions.cpython-312.pyc ADDED
Binary file (22.2 kB). View file
 
dataloaders/utils/audio_features.py ADDED
@@ -0,0 +1,80 @@
1
+ """modified from https://github.com/yesheng-THU/GFGE/blob/main/data_processing/audio_features.py"""
2
+ import numpy as np
3
+ import librosa
4
+ import math
5
+ import os
6
+ import scipy.io.wavfile as wav
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import copy
11
+ from tqdm import tqdm
12
+ from typing import Optional, Tuple
13
+ from numpy.lib import stride_tricks
14
+ from loguru import logger
15
+
16
+ # Import Wav2Vec2Model to make it available for other modules
17
+ import sys
18
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
19
+ from models.utils.wav2vec import Wav2Vec2Model
20
+
21
+
22
+
23
+ def process_audio_data(audio_file, args, data, f_name, selected_file):
24
+ """Process audio data with support for different representations."""
25
+ logger.info(f"# ---- Building cache for Audio {f_name} ---- #")
26
+
27
+ if not os.path.exists(audio_file):
28
+ logger.warning(f"# ---- file not found for Audio {f_name}, skip all files with the same id ---- #")
29
+ selected_file.drop(selected_file[selected_file['id'] == f_name].index, inplace=True)
30
+ return None
31
+
32
+ audio_save_path = audio_file.replace("wave16k", "onset_amplitude").replace(".wav", ".npy")
33
+
34
+ if args.audio_rep == "onset+amplitude" and os.path.exists(audio_save_path):
35
+ data['audio'] = np.load(audio_save_path)
36
+ logger.warning(f"# ---- file found cache for Audio {f_name} ---- #")
37
+
38
+ elif args.audio_rep == "onset+amplitude":
39
+ data['audio'] = calculate_onset_amplitude(audio_file, args.audio_sr, audio_save_path)
40
+
41
+ elif args.audio_rep == "mfcc":
42
+ audio_data, _ = librosa.load(audio_file)
43
+ data['audio'] = librosa.feature.melspectrogram(
44
+ y=audio_data,
45
+ sr=args.audio_sr,
46
+ n_mels=128,
47
+ hop_length=int(args.audio_sr/args.audio_fps)
48
+ ).transpose(1, 0)
49
+
50
+ if args.audio_norm and args.audio_rep == "wave16k":
51
+ data['audio'] = (data['audio'] - args.mean_audio) / args.std_audio
52
+
53
+ return data
54
+
55
+ def calculate_onset_amplitude(audio_file, audio_sr, save_path):
56
+ """Calculate onset and amplitude features from audio file."""
57
+ audio_data, sr = librosa.load(audio_file)
58
+ audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=audio_sr)
59
+
60
+ # Calculate amplitude envelope
61
+ frame_length = 1024
62
+ shape = (audio_data.shape[-1] - frame_length + 1, frame_length)
63
+ strides = (audio_data.strides[-1], audio_data.strides[-1])
64
+ rolling_view = stride_tricks.as_strided(audio_data, shape=shape, strides=strides)
65
+ amplitude_envelope = np.max(np.abs(rolling_view), axis=1)
66
+ amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1])
67
+
68
+ # Calculate onset
69
+ audio_onset_f = librosa.onset.onset_detect(y=audio_data, sr=audio_sr, units='frames')
70
+ onset_array = np.zeros(len(audio_data), dtype=float)
71
+ onset_array[audio_onset_f] = 1.0
72
+
73
+ # Combine features
74
+ features = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1)
75
+
76
+ # Save features
77
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
78
+ np.save(save_path, features)
79
+
80
+ return features
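Usage sketch for the helper above (the paths are placeholders; 16000 Hz mirrors the wave16k naming used in process_audio_data):

    feats = calculate_onset_amplitude("path/to/speech.wav", 16000, "path/to/cache/speech.npy")
    print(feats.shape)   # (num_audio_samples, 2): amplitude envelope plus binary onset flags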
dataloaders/utils/data_sample.py ADDED
@@ -0,0 +1,175 @@
1
+ import math
2
+ import numpy as np
3
+ from collections import defaultdict
4
+ from loguru import logger
5
+
6
+ def sample_from_clip(
7
+ lmdb_manager, audio_file, audio_each_file, pose_each_file, trans_each_file,
8
+ trans_v_each_file, shape_each_file, facial_each_file, word_each_file,
9
+ vid_each_file, emo_each_file, sem_each_file, args, ori_stride, ori_length,
10
+ disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
11
+ n_out_samples):
12
+ """Sample clips from the data according to specified parameters."""
13
+
14
+ round_seconds_skeleton = pose_each_file.shape[0] // args.pose_fps
15
+
16
+ # Calculate timing information
17
+ timing_info = calculate_timing_info(
18
+ audio_each_file, facial_each_file, round_seconds_skeleton,
19
+ args.audio_fps, args.pose_fps, args.audio_sr, args.audio_rep
20
+ )
21
+
22
+ round_seconds_skeleton = timing_info['final_seconds']
23
+
24
+ # Calculate clip boundaries
25
+ clip_info = calculate_clip_boundaries(
26
+ round_seconds_skeleton, clean_first_seconds, clean_final_seconds,
27
+ args.audio_fps, args.pose_fps
28
+ )
29
+
30
+ n_filtered_out = defaultdict(int)
31
+
32
+ # Process each training length ratio
33
+ for ratio in args.multi_length_training:
34
+ processed_data = process_data_with_ratio(
35
+ ori_stride, ori_length, ratio, clip_info, args, is_test,
36
+ audio_each_file, pose_each_file, trans_each_file, trans_v_each_file,
37
+ shape_each_file, facial_each_file, word_each_file, vid_each_file,
38
+ emo_each_file, sem_each_file, audio_file,
39
+ lmdb_manager, n_out_samples
40
+ )
41
+
42
+ for type_key, count in processed_data['filtered_counts'].items():
43
+ n_filtered_out[type_key] += count
44
+
45
+ n_out_samples = processed_data['n_out_samples']
46
+
47
+ return n_filtered_out, n_out_samples
48
+
49
+ def calculate_timing_info(audio_data, facial_data, round_seconds_skeleton,
50
+ audio_fps, pose_fps, audio_sr, audio_rep):
51
+ """Calculate timing information for the data."""
52
+ if audio_data is not None:
53
+ if audio_rep != "wave16k":
54
+ round_seconds_audio = len(audio_data) // audio_fps
55
+ elif audio_rep == "mfcc":
56
+ round_seconds_audio = audio_data.shape[0] // audio_fps
57
+ else:
58
+ round_seconds_audio = audio_data.shape[0] // audio_sr
59
+
60
+ if facial_data is not None:
61
+ round_seconds_facial = facial_data.shape[0] // pose_fps
62
+ logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s")
63
+ final_seconds = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
64
+ max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
65
+ if final_seconds != max_round:
66
+ logger.warning(f"reduce to {final_seconds}s, ignore {max_round-final_seconds}s")
67
+ else:
68
+ logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s")
69
+ final_seconds = min(round_seconds_audio, round_seconds_skeleton)
70
+ max_round = max(round_seconds_audio, round_seconds_skeleton)
71
+ if final_seconds != max_round:
72
+ logger.warning(f"reduce to {final_seconds}s, ignore {max_round-final_seconds}s")
73
+ else:
74
+ final_seconds = round_seconds_skeleton
75
+
76
+ return {
77
+ 'final_seconds': final_seconds
78
+ }
79
+
80
+ def calculate_clip_boundaries(round_seconds, clean_first_seconds, clean_final_seconds,
81
+ audio_fps, pose_fps):
82
+ """Calculate the boundaries for clip sampling."""
83
+ clip_s_t = clean_first_seconds
84
+ clip_e_t = round_seconds - clean_final_seconds
85
+
86
+ return {
87
+ 'clip_s_t': clip_s_t,
88
+ 'clip_e_t': clip_e_t,
89
+ 'clip_s_f_audio': audio_fps * clip_s_t,
90
+ 'clip_e_f_audio': clip_e_t * audio_fps,
91
+ 'clip_s_f_pose': clip_s_t * pose_fps,
92
+ 'clip_e_f_pose': clip_e_t * pose_fps
93
+ }
94
+
95
+ def process_data_with_ratio(ori_stride, ori_length, ratio, clip_info, args, is_test,
96
+ audio_data, pose_data, trans_data, trans_v_data,
97
+ shape_data, facial_data, word_data, vid_data,
98
+ emo_data, sem_data, audio_file,
99
+ lmdb_manager, n_out_samples):
100
+ """Process data with a specific training length ratio."""
101
+
102
+ if is_test and not args.test_clip:
103
+ cut_length = clip_info['clip_e_f_pose'] - clip_info['clip_s_f_pose']
104
+ args.stride = cut_length
105
+ max_length = cut_length
106
+ else:
107
+ args.stride = int(ratio * ori_stride)
108
+ cut_length = int(ori_length * ratio)
109
+
110
+ num_subdivision = math.floor(
111
+ (clip_info['clip_e_f_pose'] - clip_info['clip_s_f_pose'] - cut_length) / args.stride
112
+ ) + 1
113
+
114
+ logger.info(f"pose from frame {clip_info['clip_s_f_pose']} to {clip_info['clip_e_f_pose']}, length {cut_length}")
115
+ logger.info(f"{num_subdivision} clips is expected with stride {args.stride}")
116
+
117
+ if audio_data is not None:
118
+ audio_short_length = math.floor(cut_length / args.pose_fps * args.audio_fps)
119
+ logger.info(f"audio from frame {clip_info['clip_s_f_audio']} to {clip_info['clip_e_f_audio']}, length {audio_short_length}")
120
+
121
+ # Process subdivisions
122
+ filtered_counts = defaultdict(int)
123
+ for i in range(num_subdivision):
124
+ sample_data = extract_sample_data(
125
+ i, clip_info, cut_length, args,
126
+ audio_data, pose_data, trans_data, trans_v_data,
127
+ shape_data, facial_data, word_data, vid_data,
128
+ emo_data, sem_data, audio_file,
129
+ audio_short_length
130
+ )
131
+
132
+ if sample_data['pose'].any() is not None:
133
+ lmdb_manager.add_sample([
134
+ sample_data['pose'], sample_data['audio'], sample_data['facial'],
135
+ sample_data['shape'], sample_data['word'], sample_data['emo'],
136
+ sample_data['sem'], sample_data['vid'], sample_data['trans'],
137
+ sample_data['trans_v'], sample_data['audio_name']
138
+ ])
139
+ n_out_samples += 1
140
+
141
+ return {
142
+ 'filtered_counts': filtered_counts,
143
+ 'n_out_samples': n_out_samples
144
+ }
145
+
146
+ def extract_sample_data(idx, clip_info, cut_length, args,
147
+ audio_data, pose_data, trans_data, trans_v_data,
148
+ shape_data, facial_data, word_data, vid_data,
149
+ emo_data, sem_data, audio_file,
150
+ audio_short_length):
151
+ """Extract a single sample from the data."""
152
+ start_idx = clip_info['clip_s_f_pose'] + idx * args.stride
153
+ fin_idx = start_idx + cut_length
154
+
155
+ sample_data = {
156
+ 'pose': pose_data[start_idx:fin_idx],
157
+ 'trans': trans_data[start_idx:fin_idx],
158
+ 'trans_v': trans_v_data[start_idx:fin_idx],
159
+ 'shape': shape_data[start_idx:fin_idx],
160
+ 'facial': facial_data[start_idx:fin_idx] if args.facial_rep is not None else np.array([-1]),
161
+ 'word': word_data[start_idx:fin_idx] if args.word_rep is not None else np.array([-1]),
162
+ 'emo': emo_data[start_idx:fin_idx] if args.emo_rep is not None else np.array([-1]),
163
+ 'sem': sem_data[start_idx:fin_idx] if args.sem_rep is not None else np.array([-1]),
164
+ 'vid': vid_data[start_idx:fin_idx] if args.id_rep is not None else np.array([-1]),
165
+ 'audio_name': audio_file
166
+ }
167
+
168
+ if audio_data is not None:
169
+ audio_start = clip_info['clip_s_f_audio'] + math.floor(idx * args.stride * args.audio_fps / args.pose_fps)
170
+ audio_end = audio_start + audio_short_length
171
+ sample_data['audio'] = audio_data[audio_start:audio_end]
172
+ else:
173
+ sample_data['audio'] = np.array([-1])
174
+
175
+ return sample_data
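Worked example of the windowing above (numbers are illustrative): with pose_fps = 30, a usable range from clip_s_f_pose = 0 to clip_e_f_pose = 1800 frames (60 s), cut_length = 64 and stride = 20, process_data_with_ratio produces floor((1800 - 0 - 64) / 20) + 1 = 87 overlapping windows; each window's audio span is floor(64 / 30 * audio_fps) audio frames, starting at clip_s_f_audio + floor(i * 20 * audio_fps / 30) for the i-th window.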
dataloaders/utils/mis_features.py ADDED
@@ -0,0 +1,64 @@
1
+ # mis_features.py: semantic and emotion feature helpers
2
+ import pandas as pd
3
+ import numpy as np
4
+ from loguru import logger
5
+ import os
6
+
7
+ def process_semantic_data(sem_file, args, data, f_name):
8
+ """Process semantic representation data."""
9
+ logger.info(f"# ---- Building cache for Semantic {f_name} ---- #")
10
+
11
+ if not os.path.exists(sem_file):
12
+ logger.warning(f"# ---- file not found for Semantic {f_name} ---- #")
13
+ return None
14
+
15
+ sem_all = pd.read_csv(sem_file,
16
+ sep='\t',
17
+ names=["name", "start_time", "end_time", "duration", "score", "keywords"])
18
+
19
+ sem_data = []
20
+ for i in range(data['pose'].shape[0]):
21
+ current_time = i/args.pose_fps
22
+ found_score = False
23
+
24
+ for _, row in sem_all.iterrows():
25
+ if row['start_time'] <= current_time <= row['end_time']:
26
+ sem_data.append(row['score'])
27
+ found_score = True
28
+ break
29
+
30
+ if not found_score:
31
+ sem_data.append(0.0)
32
+
33
+ data['sem'] = np.array(sem_data)
34
+ return data
35
+
36
+ def process_emotion_data(f_name, data, args):
37
+ """Process emotion representation data."""
38
+ logger.info(f"# ---- Building cache for Emotion {f_name} ---- #")
39
+
40
+ rtype, start = int(f_name.split('_')[3]), int(f_name.split('_')[3])
41
+ if rtype in [0, 2, 4, 6]:
42
+ if 1 <= start <= 64:
43
+ score = 0
44
+ elif 65 <= start <= 72:
45
+ score = 1
46
+ elif 73 <= start <= 80:
47
+ score = 2
48
+ elif 81 <= start <= 86:
49
+ score = 3
50
+ elif 87 <= start <= 94:
51
+ score = 4
52
+ elif 95 <= start <= 102:
53
+ score = 5
54
+ elif 103 <= start <= 110:
55
+ score = 6
56
+ elif 111 <= start <= 118:
57
+ score = 7
58
+ else:
59
+ score = 0
60
+ else:
61
+ score = 0
62
+
63
+ data['emo'] = np.repeat(np.array(score).reshape(1, 1), data['pose'].shape[0], axis=0)
64
+ return data
dataloaders/utils/motion_rep_transfer.py ADDED
@@ -0,0 +1,236 @@
1
+ import smplx
2
+ import torch
3
+ import numpy as np
4
+ from . import rotation_conversions as rc
5
+ import os
6
+ import wget
7
+
8
+ download_path = "./datasets/hub"
9
+ smplx_model_dir = os.path.join(download_path, "smplx_models", "smplx")
10
+ if not os.path.exists(smplx_model_dir):
11
+ smplx_model_file_path = os.path.join(smplx_model_dir, "SMPLX_NEUTRAL_2020.npz")
12
+ os.makedirs(smplx_model_dir, exist_ok=True)
13
+ if not os.path.exists(smplx_model_file_path):
14
+ print(f"Downloading {smplx_model_file_path}")
15
+ wget.download(
16
+ "https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz",
17
+ smplx_model_file_path,
18
+ )
19
+
20
+ smplx_model = smplx.create(
21
+ "./datasets/hub/smplx_models/",
22
+ model_type='smplx',
23
+ gender='NEUTRAL_2020',
24
+ use_face_contour=False,
25
+ num_betas=300,
26
+ num_expression_coeffs=100,
27
+ ext='npz',
28
+ use_pca=False,
29
+ ).eval()
30
+
31
+ def get_motion_rep_tensor(motion_tensor, pose_fps=30, device="cuda", betas=None):
32
+ global smplx_model
33
+ smplx_model = smplx_model.to(device)
34
+ bs, n, _ = motion_tensor.shape
35
+ motion_tensor = motion_tensor.float().to(device)
36
+ motion_tensor_reshaped = motion_tensor.reshape(bs * n, 165)
37
+ betas = torch.zeros(n, 300, device=device) if betas is None else betas.to(device).unsqueeze(0).repeat(n, 1)
38
+ output = smplx_model(
39
+ betas=torch.zeros(bs * n, 300, device=device),
40
+ transl=torch.zeros(bs * n, 3, device=device),
41
+ expression=torch.zeros(bs * n, 100, device=device),
42
+ jaw_pose=torch.zeros(bs * n, 3, device=device),
43
+ global_orient=torch.zeros(bs * n, 3, device=device),
44
+ body_pose=motion_tensor_reshaped[:, 3:21 * 3 + 3],
45
+ left_hand_pose=motion_tensor_reshaped[:, 25 * 3:40 * 3],
46
+ right_hand_pose=motion_tensor_reshaped[:, 40 * 3:55 * 3],
47
+ return_joints=True,
48
+ leye_pose=torch.zeros(bs * n, 3, device=device),
49
+ reye_pose=torch.zeros(bs * n, 3, device=device),
50
+ )
51
+ joints = output['joints'].reshape(bs, n, 127, 3)[:, :, :55, :]
52
+ dt = 1 / pose_fps
53
+ init_vel = (joints[:, 1:2] - joints[:, 0:1]) / dt
54
+ middle_vel = (joints[:, 2:] - joints[:, :-2]) / (2 * dt)
55
+ final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
56
+ vel = torch.cat([init_vel, middle_vel, final_vel], dim=1)
57
+ position = joints
58
+ rot_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(bs, n, 55, 3))
59
+ rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(bs, n, 55, 6)
60
+ init_vel_ang = (motion_tensor[:, 1:2] - motion_tensor[:, 0:1]) / dt
61
+ middle_vel_ang = (motion_tensor[:, 2:] - motion_tensor[:, :-2]) / (2 * dt)
62
+ final_vel_ang = (motion_tensor[:, -1:] - motion_tensor[:, -2:-1]) / dt
63
+ angular_velocity = torch.cat([init_vel_ang, middle_vel_ang, final_vel_ang], dim=1).reshape(bs, n, 55, 3)
64
+ rep15d = torch.cat([position, vel, rot6d, angular_velocity], dim=3).reshape(bs, n, 55 * 15)
65
+ return {
66
+ "position": position,
67
+ "velocity": vel,
68
+ "rotation": rot6d,
69
+ "axis_angle": motion_tensor,
70
+ "angular_velocity": angular_velocity,
71
+ "rep15d": rep15d,
72
+ }
73
+
74
+ def get_motion_rep_numpy(poses_np, pose_fps=30, device="cuda", expressions=None, expression_only=False, betas=None):
75
+ # motion["poses"] is expected to be numpy array of shape (n, 165)
76
+ # (n, 55*3), axis-angle for 55 joints
77
+ global smplx_model
78
+ smplx_model = smplx_model.to(device)
79
+ n = poses_np.shape[0]
80
+
81
+ # Convert numpy to torch tensor for SMPL-X forward pass
82
+ poses_ts = torch.from_numpy(poses_np).float().to(device).unsqueeze(0) # (1, n, 165)
83
+ poses_ts_reshaped = poses_ts.reshape(-1, 165) # (n, 165)
84
+ betas = torch.zeros(n, 300, device=device) if betas is None else torch.from_numpy(betas).to(device).unsqueeze(0).repeat(n, 1)
85
+ if expressions is not None and expression_only:
86
+ # print("xx")
87
+ expressions = torch.from_numpy(expressions).float().to(device)
88
+ output = smplx_model(
89
+ betas=betas,
90
+ transl=torch.zeros(n, 3, device=device),
91
+ expression=expressions,
92
+ jaw_pose=poses_ts_reshaped[:, 22 * 3:23 * 3],
93
+ global_orient=torch.zeros(n, 3, device=device),
94
+ body_pose=torch.zeros(n, 21*3, device=device),
95
+ left_hand_pose=torch.zeros(n, 15*3, device=device),
96
+ right_hand_pose=torch.zeros(n, 15*3, device=device),
97
+ return_joints=True,
98
+ leye_pose=torch.zeros(n, 3, device=device),
99
+ reye_pose=torch.zeros(n, 3, device=device),
100
+ )
101
+ joints = output["vertices"].detach().cpu().numpy().reshape(n, -1)
102
+ return {"vertices": joints}
103
+
104
+ # Run smplx model to get joints
105
+ output = smplx_model(
106
+ betas=betas,
107
+ transl=torch.zeros(n, 3, device=device),
108
+ expression=torch.zeros(n, 100, device=device),
109
+ jaw_pose=torch.zeros(n, 3, device=device),
110
+ global_orient=torch.zeros(n, 3, device=device),
111
+ body_pose=poses_ts_reshaped[:, 3:21 * 3 + 3],
112
+ left_hand_pose=poses_ts_reshaped[:, 25 * 3:40 * 3],
113
+ right_hand_pose=poses_ts_reshaped[:, 40 * 3:55 * 3],
114
+ return_joints=True,
115
+ leye_pose=torch.zeros(n, 3, device=device),
116
+ reye_pose=torch.zeros(n, 3, device=device),
117
+ )
118
+ joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :]
119
+
120
+ dt = 1 / pose_fps
121
+ # Compute linear velocity
122
+ init_vel = (joints[1:2] - joints[0:1]) / dt
123
+ middle_vel = (joints[2:] - joints[:-2]) / (2 * dt)
124
+ final_vel = (joints[-1:] - joints[-2:-1]) / dt
125
+ vel = np.concatenate([init_vel, middle_vel, final_vel], axis=0)
126
+
127
+ position = joints
128
+
129
+ # Compute rotation 6D from axis-angle
130
+ poses_ts_reshaped_aa = poses_ts.reshape(1, n, 55, 3)
131
+ rot_matrices = rc.axis_angle_to_matrix(poses_ts_reshaped_aa)[0] # (n, 55, 3, 3)
132
+ rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy()
133
+
134
+ # Compute angular velocity
135
+ init_vel_ang = (poses_np[1:2] - poses_np[0:1]) / dt
136
+ middle_vel_ang = (poses_np[2:] - poses_np[:-2]) / (2 * dt)
137
+ final_vel_ang = (poses_np[-1:] - poses_np[-2:-1]) / dt
138
+ angular_velocity = np.concatenate([init_vel_ang, middle_vel_ang, final_vel_ang], axis=0).reshape(n, 55, 3)
139
+
140
+ # rep15d: position(55*3), vel(55*3), rot6d(55*6), angular_velocity(55*3) => total 55*(3+3+6+3)=55*15
141
+ rep15d = np.concatenate([position, vel, rot6d, angular_velocity], axis=2).reshape(n, 55 * 15)
142
+
143
+ return {
144
+ "position": position,
145
+ "velocity": vel,
146
+ "rotation": rot6d,
147
+ "axis_angle": poses_np,
148
+ "angular_velocity": angular_velocity,
149
+ "rep15d": rep15d,
150
+ }
151
+
152
+ def process_smplx_motion(pose_file, smplx_model, pose_fps, facial_rep=None):
153
+ """Process SMPLX pose and facial data together."""
154
+ pose_data = np.load(pose_file, allow_pickle=True)
155
+ stride = int(30/pose_fps)
156
+
157
+ # Extract pose and facial data with same stride
158
+ pose_frames = pose_data["poses"][::stride]
159
+ facial_frames = pose_data["expressions"][::stride] if facial_rep is not None else None
160
+
161
+ # Process translations
162
+ trans = pose_data["trans"][::stride]
163
+ trans[:,0] = trans[:,0] - trans[0,0]
164
+ trans[:,2] = trans[:,2] - trans[0,2]
165
+
166
+ # Calculate translation velocities
167
+ trans_v = np.zeros_like(trans)
168
+ trans_v[1:,0] = trans[1:,0] - trans[:-1,0]
169
+ trans_v[0,0] = trans_v[1,0]
170
+ trans_v[1:,2] = trans[1:,2] - trans[:-1,2]
171
+ trans_v[0,2] = trans_v[1,2]
172
+ trans_v[:,1] = trans[:,1]
173
+
174
+ # Process shape data
175
+ shape = np.repeat(pose_data["betas"].reshape(1, 300), pose_frames.shape[0], axis=0)
176
+
177
+ # # Calculate contacts
178
+ # contacts = calculate_foot_contacts(pose_data, smplx_model)
179
+
180
+ # if contacts is not None:
181
+ # pose_data = np.concatenate([pose_data, contacts], axis=1)
182
+
183
+ return {
184
+ 'pose': pose_frames,
185
+ 'trans': trans,
186
+ 'trans_v': trans_v,
187
+ 'shape': shape,
188
+ 'facial': facial_frames if facial_frames is not None else np.array([-1])
189
+ }
190
+
191
+ def calculate_foot_contacts(pose_data, smplx_model):
192
+ """Calculate foot contacts from pose data."""
193
+ max_length = 128
194
+ all_tensor = []
195
+ n = pose_data["poses"].shape[0]
196
+
197
+ # Process in batches
198
+ for i in range(n // max_length):
199
+ joints = process_joints_batch(pose_data, i, max_length, smplx_model)
200
+ all_tensor.append(joints)
201
+
202
+ # Process remaining frames
203
+ if n % max_length != 0:
204
+ r = n % max_length
205
+ joints = process_joints_batch(pose_data, n // max_length, r, smplx_model, remainder=True)
206
+ all_tensor.append(joints)
207
+
208
+ # Calculate velocities and contacts
209
+ joints = torch.cat(all_tensor, axis=0)
210
+ feetv = torch.zeros(joints.shape[1], joints.shape[0])
211
+ joints = joints.permute(1, 0, 2)
212
+ feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1)
213
+ contacts = (feetv < 0.01).numpy().astype(float)
214
+
215
+ return contacts.transpose(1, 0)
216
+
217
+ def process_joints_batch(pose_data, batch_idx, batch_size, smplx_model, remainder=False):
218
+ """Process a batch of joints for contact calculation."""
219
+ start_idx = batch_idx * batch_size
220
+ end_idx = start_idx + batch_size
221
+
222
+ with torch.no_grad():
223
+ return smplx_model(
224
+ betas=torch.from_numpy(pose_data["betas"]).cuda().float().repeat(batch_size, 1),
225
+ transl=torch.from_numpy(pose_data["trans"][start_idx:end_idx]).cuda().float(),
226
+ expression=torch.from_numpy(pose_data["expressions"][start_idx:end_idx]).cuda().float(),
227
+ jaw_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 66:69]).cuda().float(),
228
+ global_orient=torch.from_numpy(pose_data["poses"][start_idx:end_idx, :3]).cuda().float(),
229
+ body_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 3:21*3+3]).cuda().float(),
230
+ left_hand_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 25*3:40*3]).cuda().float(),
231
+ right_hand_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 40*3:55*3]).cuda().float(),
232
+ leye_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 69:72]).cuda().float(),
233
+ reye_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 72:75]).cuda().float(),
234
+ return_verts=True,
235
+ return_joints=True
236
+ )['joints'][:, (7,8,10,11), :].reshape(batch_size, 4, 3).cpu()
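Shape-level sketch for get_motion_rep_numpy (assumes the SMPL-X model download at module import succeeded; the zero pose array is purely illustrative, real inputs are the (n, 165) axis-angle arrays loaded from the .npz files handled by process_smplx_motion above):

    import numpy as np
    poses = np.zeros((16, 165), dtype=np.float32)   # 16 frames, 55 joints * 3 axis-angle values
    rep = get_motion_rep_numpy(poses, pose_fps=30, device="cpu")
    print(rep["rep15d"].shape)                      # (16, 825): 55 joints * (3 pos + 3 vel + 6 rot6d + 3 ang. vel)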
dataloaders/utils/other_tools.py ADDED
@@ -0,0 +1,748 @@
1
+ import os
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import shutil
6
+ import csv
7
+ import pprint
8
+ import pandas as pd
9
+ from loguru import logger
10
+ from collections import OrderedDict
11
+ import matplotlib.pyplot as plt
12
+ import pickle
13
+ import time
14
+ import lmdb
15
+ import numpy as np
16
+
17
+ def adjust_array(x, k):
18
+ len_x = len(x)
19
+ len_k = len(k)
20
+
21
+ # If x is shorter than k, pad with zeros
22
+ if len_x < len_k:
23
+ return np.pad(x, (0, len_k - len_x), 'constant')
24
+
25
+ # If x is longer than k, truncate x
26
+ elif len_x > len_k:
27
+ return x[:len_k]
28
+
29
+ # If both are of same length
30
+ else:
31
+ return x
32
+
33
+ def onset_to_frame(onset_times, audio_length, fps):
34
+ # Calculate total number of frames for the given audio length
35
+ total_frames = int(audio_length * fps)
36
+
37
+ # Create an array of zeros of shape (total_frames,)
38
+ frame_array = np.zeros(total_frames, dtype=np.int32)
39
+
40
+ # For each onset time, calculate the frame number and set it to 1
41
+ for onset in onset_times:
42
+ frame_num = int(onset * fps)
43
+ # Check if the frame number is within the array bounds
44
+ if 0 <= frame_num < total_frames:
45
+ frame_array[frame_num] = 1
46
+
47
+ return frame_array
48
+
49
+ def smooth_animations(animation1, animation2, blend_frames):
50
+ """
51
+ Smoothly transition between two animation clips using linear interpolation.
52
+
53
+ Parameters:
54
+ - animation1: The first animation clip, a numpy array of shape [n, k].
55
+ - animation2: The second animation clip, a numpy array of shape [n, k].
56
+ - blend_frames: Number of frames over which to blend the two animations.
57
+
58
+ Returns:
59
+ - A smoothly blended animation clip of shape [2n, k].
60
+ """
61
+
62
+ # Ensure blend_frames doesn't exceed the length of either animation
63
+ blend_frames = min(blend_frames, len(animation1), len(animation2))
64
+
65
+ # Extract overlapping sections
66
+ overlap_a1 = animation1[-blend_frames:-blend_frames+1, :]
67
+ overlap_a2 = animation2[blend_frames-1:blend_frames, :]
68
+
69
+ # Create blend weights for linear interpolation
70
+ alpha = np.linspace(0, 1, 2 * blend_frames).reshape(-1, 1)
71
+
72
+ # Linearly interpolate between overlapping sections
73
+ blended_overlap = overlap_a1 * (1 - alpha) + overlap_a2 * alpha
74
+
75
+ # Extend the animations to form the result with 2n frames
76
+ if blend_frames == len(animation1) and blend_frames == len(animation2):
77
+ result = blended_overlap
78
+ else:
79
+ before_blend = animation1[:-blend_frames]
80
+ after_blend = animation2[blend_frames:]
81
+ result = np.vstack((before_blend, blended_overlap, after_blend))
82
+ return result
83
+
84
+
85
+ def interpolate_sequence(quaternions):
86
+ bs, n, j, _ = quaternions.shape
87
+ new_n = 2 * n
88
+ new_quaternions = torch.zeros((bs, new_n, j, 4), device=quaternions.device, dtype=quaternions.dtype)
89
+
90
+ for i in range(n):
91
+ q1 = quaternions[:, i, :, :]
92
+ new_quaternions[:, 2*i, :, :] = q1
93
+
94
+ if i < n - 1:
95
+ q2 = quaternions[:, i + 1, :, :]
96
+ new_quaternions[:, 2*i + 1, :, :] = slerp(q1, q2, 0.5)
97
+ else:
98
+ # For the last point, duplicate the value
99
+ new_quaternions[:, 2*i + 1, :, :] = q1
100
+
101
+ return new_quaternions
102
+
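A quick shape check for the 2x upsampler above, assuming unit quaternions as input; the sizes are arbitrary:

import torch

quats = torch.randn(1, 8, 55, 4)
quats = quats / torch.norm(quats, dim=-1, keepdim=True)   # normalise to unit quaternions
upsampled = interpolate_sequence(quats)
print(upsampled.shape)   # (1, 16, 55, 4): every odd-indexed frame is a slerp midpoint (the last duplicates the final frame)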
103
+ def quaternion_multiply(q1, q2):
104
+ w1, x1, y1, z1 = q1
105
+ w2, x2, y2, z2 = q2
106
+ w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
107
+ x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
108
+ y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
109
+ z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
110
+ return w, x, y, z
111
+
112
+ def quaternion_conjugate(q):
113
+ w, x, y, z = q
114
+ return (w, -x, -y, -z)
115
+
116
+ def slerp(q1, q2, t):
117
+ dot = torch.sum(q1 * q2, dim=-1, keepdim=True)
118
+
119
+ flip = (dot < 0).float()
120
+ q2 = (1 - flip * 2) * q2
121
+ dot = dot * (1 - flip * 2)
122
+
123
+ DOT_THRESHOLD = 0.9995
124
+ mask = (dot > DOT_THRESHOLD).float()
125
+
126
+ theta_0 = torch.acos(torch.clamp(dot, -1.0, 1.0))  # clamp guards acos against floating-point drift
127
+ theta = theta_0 * t
128
+
129
+ q3 = q2 - q1 * dot
130
+ q3 = q3 / (torch.norm(q3, dim=-1, keepdim=True) + 1e-8)  # epsilon avoids a 0/0 NaN when q1 and q2 coincide
131
+
132
+ interpolated = (torch.cos(theta) * q1 + torch.sin(theta) * q3)
133
+
134
+ return mask * (q1 + t * (q2 - q1)) + (1 - mask) * interpolated
135
+
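A small sanity check for slerp; the quaternions below are hand-picked (w, x, y, z) values, not data from the repository:

import torch

q1 = torch.tensor([[1.0, 0.0, 0.0, 0.0]])            # identity rotation
q2 = torch.tensor([[0.7071, 0.7071, 0.0, 0.0]])      # 90 degrees about x
q_half = slerp(q1, q2, 0.5)
print(q_half)                                         # ~[0.9239, 0.3827, 0, 0], i.e. 45 degrees about x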
136
+ def estimate_linear_velocity(data_seq, dt):
137
+ '''
138
+ Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates
139
+ the velocity for the middle T-2 steps using a second order central difference scheme.
140
+ The first and last frames are estimated with forward and backward first-order
141
+ differences, respectively.
142
+ - dt : time step between consecutive frames
143
+ '''
144
+ # first steps is forward diff (t+1 - t) / dt
145
+ init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt
146
+ # middle steps are second order (t+1 - t-1) / 2dt
147
+ middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt)
148
+ # last step is backward diff (t - t-1) / dt
149
+ final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt
150
+
151
+ vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1)
152
+ return vel_seq
153
+
154
+
155
+ def estimate_angular_velocity(rot_seq, dt):
156
+ '''
157
+ Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps.
158
+ Input sequence should be of shape (B, T, ..., 3, 3)
159
+ '''
160
+ # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix
161
+ dRdt = estimate_linear_velocity(rot_seq, dt)
162
+ R = rot_seq
163
+ RT = R.transpose(-1, -2)
164
+ # compute skew-symmetric angular velocity tensor
165
+ w_mat = torch.matmul(dRdt, RT)
166
+ # pull out the angular velocity vector by averaging the two antisymmetric (skew) entries
167
+ w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0
168
+ w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0
169
+ w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0
170
+ w = torch.stack([w_x, w_y, w_z], axis=-1)
171
+ return w
172
+
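A quick consistency check for the two velocity estimators above; the batch size, joint count, and frame rate are placeholders:

import torch

B, T, J = 2, 10, 55
positions = torch.randn(B, T, J, 3)                  # joint positions
rotations = torch.eye(3).repeat(B, T, J, 1, 1)       # constant orientation -> zero angular velocity
lin_vel = estimate_linear_velocity(positions, dt=1.0 / 30)
ang_vel = estimate_angular_velocity(rotations, dt=1.0 / 30)
print(lin_vel.shape, ang_vel.shape)                  # (2, 10, 55, 3) (2, 10, 55, 3)
print(ang_vel.abs().max())                           # 0 for a constant rotation sequence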
173
+ import matplotlib.image as mpimg
174
+ from io import BytesIO
175
+
176
+ def image_from_bytes(image_bytes):
177
+ return mpimg.imread(BytesIO(image_bytes), format='PNG')
178
+
179
+
180
+
181
+ def process_frame(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1):
182
+ import matplotlib
183
+ matplotlib.use('Agg')
184
+ import matplotlib.pyplot as plt
185
+ import trimesh
186
+ from pyvirtualdisplay import Display  # noqa: F401  (the virtual display itself is started in render_one_sequence)
187
+
188
+ vertices = vertices_all[i]
189
+ vertices1 = vertices1_all[i]
190
+ filename = f"{output_dir}frame_{i}.png"
191
+ filenames.append(filename)
192
+ if i%100 == 0:
193
+ print('processed', i, 'frames')
194
+ #time_s = time.time()
195
+ #print(vertices.shape)
196
+ if use_matplotlib:
197
+ fig = plt.figure(figsize=(20, 10))
198
+ ax = fig.add_subplot(121, projection="3d")
199
+ fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
200
+ #ax.view_init(elev=0, azim=90)
201
+ x = vertices[:, 0]
202
+ y = vertices[:, 1]
203
+ z = vertices[:, 2]
204
+ ax.scatter(x, y, z, s=0.5)
205
+ ax.set_xlim([-1.0, 1.0])
206
+ ax.set_ylim([-0.5, 1.5])  # height
207
+ ax.set_zlim([-0, 2])#depth
208
+ ax.set_box_aspect((1,1,1))
209
+ else:
210
+ mesh = trimesh.Trimesh(vertices, faces)
211
+ scene = mesh.scene()
212
+ scene.camera.fov = camera_params['fov']
213
+ scene.camera.resolution = camera_params['resolution']
214
+ scene.camera.z_near = camera_params['z_near']
215
+ scene.camera.z_far = camera_params['z_far']
216
+ scene.graph[scene.camera.name] = camera_params['transform']
217
+ fig, ax =plt.subplots(1,2, figsize=(16, 6))
218
+ image = scene.save_image(resolution=[640, 480], visible=False)
219
+ im0 = ax[0].imshow(image_from_bytes(image))
220
+ ax[0].axis('off')
221
+
222
+ if use_matplotlib:
223
+ ax2 = fig.add_subplot(122, projection="3d")
224
+ ax2.set_box_aspect((1,1,1))
225
+ fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
226
+ x1 = vertices1[:, 0]
227
+ y1 = vertices1[:, 1]
228
+ z1 = vertices1[:, 2]
229
+ ax2.scatter(x1, y1, z1, s=0.5)
230
+ ax2.set_xlim([-1.0, 1.0])
231
+ ax2.set_ylim([-0.5, 1.5])  # height
232
+ ax2.set_zlim([-0, 2])
233
+ plt.savefig(filename, bbox_inches='tight')
234
+ plt.close(fig)
235
+ else:
236
+ mesh1 = trimesh.Trimesh(vertices1, faces)
237
+ scene1 = mesh1.scene()
238
+ scene1.camera.fov = camera_params1['fov']
239
+ scene1.camera.resolution = camera_params1['resolution']
240
+ scene1.camera.z_near = camera_params1['z_near']
241
+ scene1.camera.z_far = camera_params1['z_far']
242
+ scene1.graph[scene1.camera.name] = camera_params1['transform']
243
+ image1 = scene1.save_image(resolution=[640, 480], visible=False)
244
+ im1 = ax[1].imshow(image_from_bytes(image1))
245
+ ax[1].axis('off')
246
+ plt.savefig(filename, bbox_inches='tight')
247
+ plt.close(fig)
248
+
249
+ def generate_images(frames, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames):
250
+ import multiprocessing
251
+ import trimesh
252
+ num_cores = multiprocessing.cpu_count() # This will get the number of cores on your machine.
253
+ mesh = trimesh.Trimesh(vertices_all[0], faces)
254
+ scene = mesh.scene()
255
+ camera_params = {
256
+ 'fov': scene.camera.fov,
257
+ 'resolution': scene.camera.resolution,
258
+ 'focal': scene.camera.focal,
259
+ 'z_near': scene.camera.z_near,
260
+ "z_far": scene.camera.z_far,
261
+ 'transform': scene.graph[scene.camera.name][0]
262
+ }
263
+ mesh1 = trimesh.Trimesh(vertices1_all[0], faces)
264
+ scene1 = mesh1.scene()
265
+ camera_params1 = {
266
+ 'fov': scene1.camera.fov,
267
+ 'resolution': scene1.camera.resolution,
268
+ 'focal': scene1.camera.focal,
269
+ 'z_near': scene1.camera.z_near,
270
+ "z_far": scene1.camera.z_far,
271
+ 'transform': scene1.graph[scene1.camera.name][0]
272
+ }
273
+ # Use a Pool to manage the processes
274
+ # print(num_cores)
275
+ progress = multiprocessing.Value('i', 0)
276
+ lock = multiprocessing.Lock()
277
+ with multiprocessing.Pool(num_cores) as pool:
278
+ pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)])
279
+
280
+ def render_one_sequence(
281
+ res_npz_path,
282
+ gt_npz_path,
283
+ output_dir,
284
+ audio_path,
285
+ model_folder="/data/datasets/smplx_models/",
286
+ model_type='smplx',
287
+ gender='NEUTRAL_2020',
288
+ ext='npz',
289
+ num_betas=300,
290
+ num_expression_coeffs=100,
291
+ use_face_contour=False,
292
+ use_matplotlib=False,
293
+ args=None):
294
+ import smplx
295
+ import matplotlib.pyplot as plt
296
+ import imageio
297
+ from tqdm import tqdm
298
+ import os
299
+ import numpy as np
300
+ import torch
301
+ import moviepy.editor as mp
302
+ import librosa
303
+
304
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
305
+ model = smplx.create(
306
+ model_folder,
307
+ model_type=model_type,
308
+ gender=gender,
309
+ use_face_contour=use_face_contour,
310
+ num_betas=num_betas,
311
+ num_expression_coeffs=num_expression_coeffs,
312
+ ext=ext,
313
+ use_pca=False,
314
+ ).to(device)
315
+
316
+ #data_npz = np.load(f"{output_dir}{res_npz_path}.npz")
317
+ data_np_body = np.load(res_npz_path, allow_pickle=True)
318
+ gt_np_body = np.load(gt_npz_path, allow_pickle=True)
319
+
320
+ if not os.path.exists(output_dir): os.makedirs(output_dir)
321
+ filenames = []
322
+ if not use_matplotlib:
323
+ import trimesh
324
+ #import pyrender
325
+ from pyvirtualdisplay import Display
326
+ display = Display(visible=0, size=(640, 480))
327
+ display.start()
328
+ faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"]
329
+ seconds = 1
330
+ #data_npz["jaw_pose"].shape[0]
331
+ n = data_np_body["poses"].shape[0]
332
+ beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device)
333
+ beta = beta.repeat(n, 1)
334
+ expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).to(device)
335
+ jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).to(device)
336
+ pose = torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).to(device)
337
+ transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).to(device)
338
+ # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape)
339
+ output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose,
340
+ global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3],
341
+ leye_pose=pose[:, 69:72],
342
+ reye_pose=pose[:, 72:75],
343
+ return_verts=True)
344
+ vertices_all = output["vertices"].cpu().detach().numpy()
345
+
346
+ beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device)
347
+ expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).to(device)
348
+ jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).to(device)
349
+ pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).to(device)
350
+ transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).to(device)
351
+ output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3],
352
+ leye_pose=pose1[:, 69:72],
353
+ reye_pose=pose1[:, 72:75],return_verts=True)
354
+ vertices1_all = output1["vertices"].cpu().detach().numpy()
355
+ if args.debug:
356
+ seconds = 1
357
+ else:
358
+ seconds = vertices_all.shape[0]//30
359
+ # camera_settings = None
360
+ time_s = time.time()
361
+ generate_images(int(seconds*30), vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames)
362
+ filenames = [f"{output_dir}frame_{i}.png" for i in range(int(seconds*30))]
363
+ # print(time.time()-time_s)
364
+ # for i in tqdm(range(seconds*30)):
365
+ # vertices = vertices_all[i]
366
+ # vertices1 = vertices1_all[i]
367
+ # filename = f"{output_dir}frame_{i}.png"
368
+ # filenames.append(filename)
369
+ # #time_s = time.time()
370
+ # #print(vertices.shape)
371
+ # if use_matplotlib:
372
+ # fig = plt.figure(figsize=(20, 10))
373
+ # ax = fig.add_subplot(121, projection="3d")
374
+ # fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
375
+ # #ax.view_init(elev=0, azim=90)
376
+ # x = vertices[:, 0]
377
+ # y = vertices[:, 1]
378
+ # z = vertices[:, 2]
379
+ # ax.scatter(x, y, z, s=0.5)
380
+ # ax.set_xlim([-1.0, 1.0])
381
+ # ax.set_ylim([-0.5, 1.5])#heigth
382
+ # ax.set_zlim([-0, 2])#depth
383
+ # ax.set_box_aspect((1,1,1))
384
+ # else:
385
+ # mesh = trimesh.Trimesh(vertices, faces)
386
+ # if i == 0:
387
+ # scene = mesh.scene()
388
+ # camera_params = {
389
+ # 'fov': scene.camera.fov,
390
+ # 'resolution': scene.camera.resolution,
391
+ # 'focal': scene.camera.focal,
392
+ # 'z_near': scene.camera.z_near,
393
+ # "z_far": scene.camera.z_far,
394
+ # 'transform': scene.graph[scene.camera.name][0]
395
+ # }
396
+ # else:
397
+ # scene = mesh.scene()
398
+ # scene.camera.fov = camera_params['fov']
399
+ # scene.camera.resolution = camera_params['resolution']
400
+ # scene.camera.z_near = camera_params['z_near']
401
+ # scene.camera.z_far = camera_params['z_far']
402
+ # scene.graph[scene.camera.name] = camera_params['transform']
403
+ # fig, ax =plt.subplots(1,2, figsize=(16, 6))
404
+ # image = scene.save_image(resolution=[640, 480], visible=False)
405
+ # #print((time.time()-time_s))
406
+ # im0 = ax[0].imshow(image_from_bytes(image))
407
+ # ax[0].axis('off')
408
+
409
+ # # beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0)
410
+ # # expression1 = torch.from_numpy(gt_np_body["expressions"][i]).to(torch.float32).unsqueeze(0)
411
+ # # jaw_pose1 = torch.from_numpy(gt_np_body["poses"][i][66:69]).to(torch.float32).unsqueeze(0)
412
+ # # pose1 = torch.from_numpy(gt_np_body["poses"][i]).to(torch.float32).unsqueeze(0)
413
+ # # transl1 = torch.from_numpy(gt_np_body["trans"][i]).to(torch.float32).unsqueeze(0)
414
+ # # #print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape)global_orient=pose[0:1,:3],
415
+ # # output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[0:1,:3], body_pose=pose1[0:1,3:21*3+3], left_hand_pose=pose1[0:1,25*3:40*3], right_hand_pose=pose1[0:1,40*3:55*3], return_verts=True)
416
+ # # vertices1 = output1["vertices"].cpu().detach().numpy()[0]
417
+
418
+ # if use_matplotlib:
419
+ # ax2 = fig.add_subplot(122, projection="3d")
420
+ # ax2.set_box_aspect((1,1,1))
421
+ # fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
422
+ # #ax2.view_init(elev=0, azim=90)
423
+ # x1 = vertices1[:, 0]
424
+ # y1 = vertices1[:, 1]
425
+ # z1 = vertices1[:, 2]
426
+ # ax2.scatter(x1, y1, z1, s=0.5)
427
+ # ax2.set_xlim([-1.0, 1.0])
428
+ # ax2.set_ylim([-0.5, 1.5])#heigth
429
+ # ax2.set_zlim([-0, 2])
430
+ # plt.savefig(filename, bbox_inches='tight')
431
+ # plt.close(fig)
432
+ # else:
433
+ # mesh1 = trimesh.Trimesh(vertices1, faces)
434
+ # if i == 0:
435
+ # scene1 = mesh1.scene()
436
+ # camera_params1 = {
437
+ # 'fov': scene1.camera.fov,
438
+ # 'resolution': scene1.camera.resolution,
439
+ # 'focal': scene1.camera.focal,
440
+ # 'z_near': scene1.camera.z_near,
441
+ # "z_far": scene1.camera.z_far,
442
+ # 'transform': scene1.graph[scene1.camera.name][0]
443
+ # }
444
+ # else:
445
+ # scene1 = mesh1.scene()
446
+ # scene1.camera.fov = camera_params1['fov']
447
+ # scene1.camera.resolution = camera_params1['resolution']
448
+ # scene1.camera.z_near = camera_params1['z_near']
449
+ # scene1.camera.z_far = camera_params1['z_far']
450
+ # scene1.graph[scene1.camera.name] = camera_params1['transform']
451
+ # image1 = scene1.save_image(resolution=[640, 480], visible=False)
452
+ # im1 = ax[1].imshow(image_from_bytes(image1))
453
+ # ax[1].axis('off')
454
+ # plt.savefig(filename, bbox_inches='tight')
455
+ # plt.close(fig)
456
+
457
+ # display.stop()
458
+ # print(filenames)
459
+ images = [imageio.imread(filename) for filename in filenames]
460
+ imageio.mimsave(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4", images, fps=30)
461
+ for filename in filenames:
462
+ os.remove(filename)
463
+
464
+ video = mp.VideoFileClip(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4")
465
+ # audio, sr = librosa.load(audio_path)
466
+ # audio = audio[:seconds*sr]
467
+ # print(audio.shape, seconds, sr)
468
+ # import soundfile as sf
469
+ # sf.write(f"{output_dir}{res_npz_path.split('/')[-1][:-4]}.wav", audio, 16000, 'PCM_24')
470
+ # audio_tmp = librosa.output.write_wav(f"{output_dir}{res_npz_path.split('/')[-1][:-4]}.wav", audio, sr=16000)
471
+ audio = mp.AudioFileClip(audio_path)
472
+ if audio.duration > video.duration:
473
+ audio = audio.subclip(0, video.duration)
474
+ final_clip = video.set_audio(audio)
475
+ final_clip.write_videofile(f"{output_dir}{res_npz_path.split('/')[-1][4:-4]}.mp4")
476
+ os.remove(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4")
477
+
478
+ def print_exp_info(args):
479
+ logger.info(pprint.pformat(vars(args)))
480
+ logger.info(f"# ------------ {args.name} ----------- #")
481
+ logger.info("PyTorch version: {}".format(torch.__version__))
482
+ logger.info("CUDA version: {}".format(torch.version.cuda))
483
+ logger.info("{} GPUs".format(torch.cuda.device_count()))
484
+ logger.info(f"Random Seed: {args.random_seed}")
485
+
486
+ def args2csv(args, get_head=False, list4print=None):
487
+ if list4print is None: list4print = []  # avoid sharing one mutable default list across calls
+ for k, v in args.items():
488
+ if isinstance(args[k], dict):
489
+ args2csv(args[k], get_head, list4print)
490
+ else: list4print.append(k) if get_head else list4print.append(v)
491
+ return list4print
492
+
493
+ class EpochTracker:
494
+ def __init__(self, metric_names, metric_directions):
495
+ assert len(metric_names) == len(metric_directions), "Metric names and directions should have the same length"
496
+
497
+
498
+ self.metric_names = metric_names
499
+ self.states = ['train', 'val', 'test']
500
+ self.types = ['last', 'best']
501
+
502
+
503
+ self.values = {name: {state: {type_: {'value': np.inf if not is_higher_better else -np.inf, 'epoch': 0}
504
+ for type_ in self.types}
505
+ for state in self.states}
506
+ for name, is_higher_better in zip(metric_names, metric_directions)}
507
+
508
+ self.loss_meters = {name: {state: AverageMeter(f"{name}_{state}")
509
+ for state in self.states}
510
+ for name in metric_names}
511
+
512
+
513
+ self.is_higher_better = {name: direction for name, direction in zip(metric_names, metric_directions)}
514
+ self.train_history = {name: [] for name in metric_names}
515
+ self.val_history = {name: [] for name in metric_names}
516
+
517
+
518
+ def update_meter(self, name, state, value):
519
+ self.loss_meters[name][state].update(value)
520
+
521
+
522
+ def update_values(self, name, state, epoch):
523
+ value_avg = self.loss_meters[name][state].avg
524
+ new_best = False
525
+
526
+
527
+ if ((value_avg < self.values[name][state]['best']['value'] and not self.is_higher_better[name]) or
528
+ (value_avg > self.values[name][state]['best']['value'] and self.is_higher_better[name])):
529
+ self.values[name][state]['best']['value'] = value_avg
530
+ self.values[name][state]['best']['epoch'] = epoch
531
+ new_best = True
532
+ self.values[name][state]['last']['value'] = value_avg
533
+ self.values[name][state]['last']['epoch'] = epoch
534
+ return new_best
535
+
536
+
537
+ def get(self, name, state, type_):
538
+ return self.values[name][state][type_]
539
+
540
+
541
+ def reset(self):
542
+ for name in self.metric_names:
543
+ for state in self.states:
544
+ self.loss_meters[name][state].reset()
545
+
546
+
547
+ def flatten_values(self):
548
+ flat_dict = {}
549
+ for name in self.metric_names:
550
+ for state in self.states:
551
+ for type_ in self.types:
552
+ value_key = f"{name}_{state}_{type_}"
553
+ epoch_key = f"{name}_{state}_{type_}_epoch"
554
+ flat_dict[value_key] = self.values[name][state][type_]['value']
555
+ flat_dict[epoch_key] = self.values[name][state][type_]['epoch']
556
+ return flat_dict
557
+
558
+ def update_and_plot(self, name, epoch, save_path):
559
+ new_best_train = self.update_values(name, 'train', epoch)
560
+ new_best_val = self.update_values(name, 'val', epoch)
561
+
562
+
563
+ self.train_history[name].append(self.loss_meters[name]['train'].avg)
564
+ self.val_history[name].append(self.loss_meters[name]['val'].avg)
565
+
566
+
567
+ train_values = self.train_history[name]
568
+ val_values = self.val_history[name]
569
+ epochs = list(range(1, len(train_values) + 1))
570
+
571
+
572
+ plt.figure(figsize=(10, 6))
573
+ plt.plot(epochs, train_values, label='Train')
574
+ plt.plot(epochs, val_values, label='Val')
575
+ plt.title(f'Train vs Val {name} over epochs')
576
+ plt.xlabel('Epochs')
577
+ plt.ylabel(name)
578
+ plt.legend()
579
+ plt.savefig(save_path)
580
+ plt.close()
581
+
582
+
583
+ return new_best_train, new_best_val
584
+
585
+
586
+
587
+
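A minimal sketch of how the tracker above is meant to be driven from a training loop; the metric name, values, and plot path are illustrative:

tracker = EpochTracker(["loss"], [False])            # False = lower is better
for epoch in range(2):
    tracker.update_meter("loss", "train", 0.5 / (epoch + 1))
    tracker.update_meter("loss", "val", 0.6 / (epoch + 1))
    new_best_train, new_best_val = tracker.update_and_plot("loss", epoch, f"loss_epoch{epoch}.png")
    tracker.reset()
print(tracker.get("loss", "val", "best"))            # {'value': 0.3, 'epoch': 1}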
588
+ def record_trial(args, tracker):
589
+ """
590
+ 1. record notes, score, env_name, experiments_path
591
+ """
592
+ csv_path = args.out_path + "custom/" +args.csv_name+".csv"
593
+ all_print_dict = vars(args)
594
+ all_print_dict.update(tracker.flatten_values())
595
+ if not os.path.exists(csv_path):
596
+ pd.DataFrame([all_print_dict]).to_csv(csv_path, index=False)
597
+ else:
598
+ df_existing = pd.read_csv(csv_path)
599
+ df_new = pd.DataFrame([all_print_dict])
600
+ df_aligned = pd.concat([df_existing, df_new]).fillna("")  # DataFrame.append was removed in pandas 2.0
601
+ df_aligned.to_csv(csv_path, index=False)
602
+
603
+
604
+ def set_random_seed(args):
605
+ os.environ['PYTHONHASHSEED'] = str(args.random_seed)
606
+ random.seed(args.random_seed)
607
+ np.random.seed(args.random_seed)
608
+ torch.manual_seed(args.random_seed)
609
+ torch.cuda.manual_seed_all(args.random_seed)
610
+ torch.cuda.manual_seed(args.random_seed)
611
+ torch.backends.cudnn.deterministic = args.deterministic #args.CUDNN_DETERMINISTIC
612
+ torch.backends.cudnn.benchmark = args.benchmark
613
+ torch.backends.cudnn.enabled = args.cudnn_enabled
614
+
615
+
616
+ def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None):
617
+ if lrs is not None:
618
+ states = { 'model_state': model.state_dict(),
619
+ 'epoch': epoch + 1,
620
+ 'opt_state': opt.state_dict(),
621
+ 'lrs':lrs.state_dict(),}
622
+ elif opt is not None:
623
+ states = { 'model_state': model.state_dict(),
624
+ 'epoch': epoch + 1,
625
+ 'opt_state': opt.state_dict(),}
626
+ else:
627
+ states = { 'model_state': model.state_dict(),}
628
+ torch.save(states, save_path)
629
+
630
+
631
+ def load_checkpoints(model, save_path, load_name='model'):
632
+ states = torch.load(save_path)
633
+ new_weights = OrderedDict()
634
+ flag=False
635
+ for k, v in states['model_state'].items():
636
+ #print(k)
637
+ if "module" not in k:
638
+ break
639
+ else:
640
+ new_weights[k[7:]]=v
641
+ flag=True
642
+ if flag:
643
+ try:
644
+ model.load_state_dict(new_weights)
645
+ except Exception:  # fall back to loading the raw (non-stripped) state dict
646
+ #print(states['model_state'])
647
+ model.load_state_dict(states['model_state'])
648
+ else:
649
+ model.load_state_dict(states['model_state'])
650
+ logger.info(f"load self-pretrained checkpoints for {load_name}")
651
+
652
+
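A brief sketch of the save/load round trip above; the model, optimiser, and file name are placeholders:

import torch
import torch.nn as nn

net = nn.Linear(10, 10)
opt = torch.optim.Adam(net.parameters())
save_checkpoints("last.bin", net, opt=opt, epoch=0)
load_checkpoints(net, "last.bin", load_name="demo")   # also strips a leading "module." from DataParallel keys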
653
+ def model_complexity(model, args):
654
+ from ptflops import get_model_complexity_info
655
+ flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN),
656
+ as_strings=False, print_per_layer_stat=False)
657
+ logger.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9))
658
+ logger.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6))
659
+
660
+
661
+ class AverageMeter(object):
662
+ """Computes and stores the average and current value"""
663
+ def __init__(self, name, fmt=':f'):
664
+ self.name = name
665
+ self.fmt = fmt
666
+ self.reset()
667
+
668
+ def reset(self):
669
+ self.val = 0
670
+ self.avg = 0
671
+ self.sum = 0
672
+ self.count = 0
673
+
674
+ def update(self, val, n=1):
675
+ self.val = val
676
+ self.sum += val * n
677
+ self.count += n
678
+ self.avg = self.sum / self.count
679
+
680
+ def __str__(self):
681
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
682
+ return fmtstr.format(**self.__dict__)
683
+
684
+
685
+
686
+ class MultiLMDBManager:
687
+ def __init__(self, base_dir, max_db_size=10*1024*1024*1024): # 10GB default size
688
+ self.base_dir = base_dir
689
+ self.max_db_size = max_db_size
690
+ self.current_db_size = 0
691
+ self.current_db_idx = 0
692
+ self.current_lmdb_env = None
693
+ self.sample_to_db_mapping = {}
694
+ self.sample_counter = 0
695
+ self.db_paths = []
696
+
697
+ def get_new_lmdb_path(self):
698
+ db_path = os.path.join(self.base_dir, f"db_{self.current_db_idx:03d}")
699
+ self.db_paths.append(db_path)
700
+ return db_path
701
+
702
+ def init_new_db(self):
703
+ if self.current_lmdb_env is not None:
704
+ self.current_lmdb_env.sync()
705
+ self.current_lmdb_env.close()
706
+
707
+ new_db_path = self.get_new_lmdb_path()
708
+ self.current_lmdb_env = lmdb.open(new_db_path, map_size=self.max_db_size)
709
+ self.current_db_size = 0
710
+ self.current_db_idx += 1
711
+ return self.current_lmdb_env
712
+
713
+ def add_sample(self, sample_data):
714
+ if self.current_lmdb_env is None:
715
+ self.init_new_db()
716
+
717
+ v = pickle.dumps(sample_data)
718
+ sample_size = len(v)
719
+
720
+ try:
721
+ sample_key = "{:008d}".format(self.sample_counter).encode("ascii")
722
+ with self.current_lmdb_env.begin(write=True) as txn:
723
+ txn.put(sample_key, v)
724
+ self.sample_to_db_mapping[self.sample_counter] = self.current_db_idx - 1
725
+
726
+ except lmdb.MapFullError:
727
+ self.init_new_db()
728
+ sample_key = "{:008d}".format(self.sample_counter).encode("ascii")
729
+ with self.current_lmdb_env.begin(write=True) as txn:
730
+ txn.put(sample_key, v)
731
+ self.sample_to_db_mapping[self.sample_counter] = self.current_db_idx - 1
732
+
733
+ self.current_db_size += sample_size
734
+ self.sample_counter += 1
735
+
736
+ def save_mapping(self):
737
+ mapping_path = os.path.join(self.base_dir, "sample_db_mapping.pkl")
738
+ with open(mapping_path, 'wb') as f:
739
+ pickle.dump({
740
+ 'mapping': self.sample_to_db_mapping,
741
+ 'db_paths': self.db_paths
742
+ }, f)
743
+
744
+ def close(self):
745
+ if self.current_lmdb_env is not None:
746
+ self.current_lmdb_env.sync()
747
+ self.current_lmdb_env.close()
748
+ self.save_mapping()
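A short sketch of the intended write-side usage of the LMDB manager above; the cache directory, shard size, and payload layout are placeholders:

import os
import numpy as np

os.makedirs("./lmdb_cache", exist_ok=True)
manager = MultiLMDBManager("./lmdb_cache", max_db_size=512 * 1024 * 1024)   # 512 MB shards
for i in range(100):
    sample = {"idx": i, "pose": np.zeros((32, 165), dtype=np.float32)}
    manager.add_sample(sample)           # pickled and stored under an 8-digit ascii key
manager.close()                          # flushes the open shard and writes sample_db_mapping.pkl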