Upload 149 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +5 -0
- configs/beat2_rvqvae.yaml +134 -0
- configs/diffuser_rvqvae_128.yaml +96 -0
- configs/model_config.yaml +71 -0
- configs/sc_model_config.yaml +37 -0
- configs/sc_model_holistic_config.yaml +37 -0
- configs/sc_reflow_model_config.yaml +37 -0
- configs/shortcut.yaml +96 -0
- configs/shortcut_hf.yaml +96 -0
- configs/shortcut_holistic.yaml +96 -0
- configs/shortcut_reflow.yaml +96 -0
- configs/shortcut_reflow_test.yaml +96 -0
- configs/shortcut_rvqvae_128.yaml +96 -0
- configs/shortcut_rvqvae_128_hf.yaml +96 -0
- dataloaders/__pycache__/beat_sep_single.cpython-312.pyc +0 -0
- dataloaders/__pycache__/build_vocab.cpython-312.pyc +0 -0
- dataloaders/__pycache__/data_tools.cpython-312.pyc +0 -0
- dataloaders/beat_dataset_new.py +373 -0
- dataloaders/beat_sep.py +772 -0
- dataloaders/beat_sep_lower.py +430 -0
- dataloaders/beat_sep_single.py +693 -0
- dataloaders/beat_smplx2020.py +763 -0
- dataloaders/build_vocab.py +199 -0
- dataloaders/data_tools.py +1756 -0
- dataloaders/mix_sep.py +301 -0
- dataloaders/pymo/Quaternions.py +468 -0
- dataloaders/pymo/__init__.py +0 -0
- dataloaders/pymo/__pycache__/Quaternions.cpython-312.pyc +0 -0
- dataloaders/pymo/__pycache__/__init__.cpython-312.pyc +0 -0
- dataloaders/pymo/__pycache__/data.cpython-312.pyc +0 -0
- dataloaders/pymo/__pycache__/parsers.cpython-312.pyc +0 -0
- dataloaders/pymo/__pycache__/preprocessing.cpython-312.pyc +0 -0
- dataloaders/pymo/__pycache__/rotation_tools.cpython-312.pyc +0 -0
- dataloaders/pymo/__pycache__/viz_tools.cpython-312.pyc +0 -0
- dataloaders/pymo/data.py +53 -0
- dataloaders/pymo/features.py +43 -0
- dataloaders/pymo/parsers.py +274 -0
- dataloaders/pymo/preprocessing.py +726 -0
- dataloaders/pymo/rotation_tools.py +153 -0
- dataloaders/pymo/rotation_tools.py! +69 -0
- dataloaders/pymo/viz_tools.py +236 -0
- dataloaders/pymo/writers.py +55 -0
- dataloaders/utils/__pycache__/audio_features.cpython-312.pyc +0 -0
- dataloaders/utils/__pycache__/other_tools.cpython-312.pyc +0 -0
- dataloaders/utils/__pycache__/rotation_conversions.cpython-312.pyc +0 -0
- dataloaders/utils/audio_features.py +80 -0
- dataloaders/utils/data_sample.py +175 -0
- dataloaders/utils/mis_features.py +64 -0
- dataloaders/utils/motion_rep_transfer.py +236 -0
- dataloaders/utils/other_tools.py +748 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+demo/examples/2_scott_0_1_1.wav filter=lfs diff=lfs merge=lfs -text
+demo/examples/2_scott_0_2_2.wav filter=lfs diff=lfs merge=lfs -text
+demo/examples/2_scott_0_3_3.wav filter=lfs diff=lfs merge=lfs -text
+demo/examples/2_scott_0_4_4.wav filter=lfs diff=lfs merge=lfs -text
+demo/examples/2_scott_0_5_5.wav filter=lfs diff=lfs merge=lfs -text
configs/beat2_rvqvae.yaml ADDED
@@ -0,0 +1,134 @@
+is_train: True
+ddp: False
+stat: ts
+root_path: ./
+out_path: ./outputs/audio2pose/
+project: s2g
+data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+e_path: weights/AESKConv_240_100.bin
+eval_model: motion_representation
+e_name: VAESKConv
+test_ckpt: ./outputs/audio2pose/custom/0112_001634_emage/last_500.bin
+data_path_1: ./datasets/hub/
+
+vae_test_len: 32
+vae_test_dim: 330
+vae_test_stride: 20
+vae_length: 240
+vae_codebook_size: 256
+vae_layer: 4
+vae_grow: [1,1,2,1]
+variational: False
+
+# data config
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] #[2]
+additional_data: False
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_rvqvae/
+dataset: mix_sep
+new_cache: True
+use_amass: False
+# motion config
+ori_joints: beat_smplx_joints
+tar_joints: beat_smplx_full
+pose_rep: smplxflame_30
+pose_norm: False
+pose_fps: 30
+rot6d: True
+pre_frames: 4
+pose_dims: 330
+pose_length: 64
+stride: 20
+test_length: 64
+motion_f: 256
+m_pre_encoder: null
+m_encoder: null
+m_fix_pre: False
+
+# audio config
+audio_rep: onset+amplitude
+audio_sr: 16000
+audio_fps: 16000
+audio_norm: False
+audio_f: 256
+# a_pre_encoder: tcn_camn
+# a_encoder: none
+# a_fix_pre: False
+
+# text config
+word_rep: textgrid
+word_index_num: 11195
+word_dims: 300
+freeze_wordembed: False
+word_f: 256
+t_pre_encoder: fasttext
+t_encoder: null
+t_fix_pre: False
+
+# facial config
+facial_rep: smplxflame_30
+facial_dims: 100
+facial_norm: False
+facial_f: 0
+f_pre_encoder: null
+f_encoder: null
+f_fix_pre: False
+
+# speaker config
+id_rep: onehot
+speaker_f: 0
+
+# model config
+batch_size: 80 #80
+# warmup_epochs: 1
+# warmup_lr: 1e-6
+lr_base: 4e-4
+model: motion_representation
+g_name: VQVAEConvZero
+trainer: ae_total
+hidden_size: 768
+n_layer: 1
+
+rec_weight: 1
+grad_norm: 0.99
+epochs: 200
+test_period: 20
+ll: 3
+lf: 3
+lu: 3
+lh: 3
+cl: 1
+cf: 0
+cu: 1
+ch: 1
+
+
+
+# below is the VQ-VAE config, copied from QPGesture
+# Codebook Configs
+levels: 1
+downs_t: [3]
+strides_t : [2]
+emb_width : 512
+l_bins : 512
+l_mu : 0.99
+commit : 0.1
+hvqvae_multipliers : [1]
+width: 512
+depth: 3
+m_conv : 1.0
+dilation_growth_rate : 3
+sample_length: 80
+use_bottleneck: True
+joint_channel: 6
+# depth: 3
+# width: 128
+# m_conv: 1.0
+# dilation_growth_rate: 1
+# dilation_cycle: None
+vel: 1 # 1 -> 0
+acc: 1 # 1 -> 0
+vqvae_reverse_decoder_dilation: True
+
+
+## below is special for emage
+rec_pos_weight : 1.0
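These flat key-value configs are typically parsed into an attribute-style namespace before being handed to a trainer. A minimal sketch of that pattern follows; the load_config helper is illustrative and not part of this repository:

import yaml
from types import SimpleNamespace

def load_config(path):
    """Parse a flat YAML config into an attribute-style namespace."""
    with open(path) as f:
        cfg = yaml.safe_load(f)  # e.g. {'is_train': True, 'pose_fps': 30, ...}
    return SimpleNamespace(**cfg)

args = load_config("configs/beat2_rvqvae.yaml")
print(args.pose_rep, args.vae_codebook_size)  # smplxflame_30 256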
configs/diffuser_rvqvae_128.yaml ADDED
@@ -0,0 +1,96 @@
+is_train: True
+ddp: False
+stat: ts
+root_path: ./
+out_path: ./outputs/audio2pose/
+project: s2g
+e_path: weights/AESKConv_240_100.bin
+eval_model: motion_representation
+e_name: VAESKConv
+data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+test_ckpt: ./ckpt/new_540_diffusion.bin
+data_path_1: ./datasets/hub/
+pose_norm: True
+cfg: configs/model_config.yaml
+
+
+mean_pose_path: ./mean_std/beatx_2_330_mean.npy
+std_pose_path: ./mean_std/beatx_2_330_std.npy
+
+mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
+std_trans_path: ./mean_std/beatx_2_trans_std.npy
+
+
+vqvae_upper_path: ./ckpt/net_300000_upper.pth
+vqvae_hands_path: ./ckpt/net_300000_hands.pth
+vqvae_lower_path: ./ckpt/net_300000_lower.pth
+
+vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
+use_trans: True
+
+decay_epoch: 500
+
+vqvae_squeeze_scale: 4
+vqvae_latent_scale: 5
+
+vae_test_len: 32
+vae_test_dim: 330
+vae_test_stride: 20
+vae_length: 240
+vae_codebook_size: 256
+vae_layer: 4
+vae_grow: [1,1,2,1]
+variational: False
+
+# data config
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
+additional_data: False
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
+dataset: beat_sep_lower
+new_cache: False
+
+# motion config
+ori_joints: beat_smplx_joints
+tar_joints: beat_smplx_full
+pose_rep: smplxflame_30
+pose_fps: 30
+rot6d: True
+pre_frames: 4
+pose_dims: 330
+pose_length: 128
+stride: 20
+test_length: 128
+m_fix_pre: False
+
+
+audio_rep: onset+amplitude
+audio_sr: 16000
+audio_fps: 16000
+audio_norm: False
+audio_f: 256
+audio_raw: None
+
+
+word_rep: textgrid
+word_dims: 300
+t_pre_encoder: fasttext
+
+
+facial_rep: smplxflame_30
+facial_dims: 100
+facial_norm: False
+facial_f: 0
+
+
+id_rep: onehot
+speaker_f: 0
+
+
+batch_size: 128
+lr_base: 2e-4
+trainer: diffuser_rvqvae
+
+rec_weight: 1
+grad_norm: 0.99
+epochs: 1000
+test_period: 20
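With pose_norm: True, the mean_pose_path/std_pose_path arrays are per-dimension statistics used to z-normalize poses before training and denormalize them afterwards. A sketch of that convention, assuming the .npy files hold per-dimension vectors over the 330-D pose (pose_dims: 330); the helper names are hypothetical:

import numpy as np

mean_pose = np.load("./mean_std/beatx_2_330_mean.npy")  # assumed shape (330,)
std_pose = np.load("./mean_std/beatx_2_330_std.npy")

def normalize(pose):
    """Z-normalize a (T, 330) pose array with the stored statistics."""
    return (pose - mean_pose) / (std_pose + 1e-8)

def denormalize(pose):
    """Invert the normalization for visualization or metrics."""
    return pose * (std_pose + 1e-8) + mean_pose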
configs/model_config.yaml ADDED
@@ -0,0 +1,71 @@
+model:
+  model_name: GestureDiffuse
+  g_name: GestureDiffusion
+  do_classifier_free_guidance: False
+  guidance_scale: 1.5
+
+  denoiser:
+    target: models.denoiser.GestureDenoiser
+    params:
+      input_dim: 128
+      latent_dim: 256
+      ff_size: 1024
+      num_layers: 8
+      num_heads: 4
+      dropout: 0.1
+      activation: "gelu"
+      n_seed: 8
+      flip_sin_to_cos: True
+      freq_shift: 0.0
+
+
+
+  modality_encoder:
+    target: models.modality_encoder.ModalityEncoder
+    params:
+      data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+      t_fix_pre: False
+      audio_dim: 256
+      audio_in: 2
+      raw_audio: False
+      latent_dim: 256
+      audio_fps: 30
+
+
+  scheduler:
+    target: diffusers.DDIMScheduler
+    num_inference_steps: 20
+    eta: 0.0
+    params:
+      num_train_timesteps: 1000
+      # if using 'linear' or 'scaled_linear', beta_start and beta_end matter; if cosine, they are ignored
+      beta_start: 0.00085
+      beta_end: 0.012
+      # 'linear' or 'squaredcos_cap_v2' or 'scaled_linear'
+      beta_schedule: 'squaredcos_cap_v2'
+      prediction_type: 'sample'
+      clip_sample: false
+      # 'leading' or 'trailing' or 'linspace'
+      timestep_spacing: 'leading'
+      # below are for ddim
+      set_alpha_to_one: True
+      steps_offset: 0
+
+
+  # use ddpm scheduler
+  # scheduler:
+  #   target: diffusers.DDPMScheduler
+  #   num_inference_steps: 50
+  #   eta: 0.0
+  #   params:
+  #     num_train_timesteps: 1000
+  #     beta_start: 0.00085
+  #     beta_end: 0.012
+  #     beta_schedule: 'squaredcos_cap_v2' # 'squaredcos_cap_v2'
+  #     prediction_type: 'sample'
+  #     clip_sample: false
+  #     variance_type: 'fixed_small_log'
+  #     # below are for ddim
+  #     # set_alpha_to_one: True
+  #     # steps_offset: 1
+
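The target/params layout is the common "dotted import path plus constructor kwargs" convention. A sketch of how such nodes are usually resolved; this instantiate_from_config helper is an assumption, not confirmed repository code, but the params keys shown above are genuine diffusers.DDIMScheduler constructor arguments:

import importlib

def instantiate_from_config(node, **extra):
    """Resolve a {'target': 'pkg.mod.Class', 'params': {...}} node into an object."""
    module_name, cls_name = node["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**node.get("params", {}), **extra)

# e.g. scheduler = instantiate_from_config(cfg["model"]["scheduler"])
# builds diffusers.DDIMScheduler(num_train_timesteps=1000,
#                                beta_schedule='squaredcos_cap_v2', ...)

Note that num_inference_steps and eta sit outside params: they are sampling-time arguments, passed to scheduler.set_timesteps(20) and to scheduler.step(..., eta=0.0) rather than to the constructor.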
configs/sc_model_config.yaml ADDED
@@ -0,0 +1,37 @@
+model:
+  model_name: LSM
+  g_name: GestureLSM
+  do_classifier_free_guidance: False
+  guidance_scale: 2
+  n_steps: 20
+  use_exp: False
+
+  denoiser:
+    target: models.denoiser.GestureDenoiser
+    params:
+      input_dim: 128
+      latent_dim: 256
+      ff_size: 1024
+      num_layers: 8
+      num_heads: 4
+      dropout: 0.1
+      activation: "gelu"
+      n_seed: 8
+      flip_sin_to_cos: True
+      freq_shift: 0.0
+      cond_proj_dim: 256
+      use_exp: ${model.use_exp}
+
+
+  modality_encoder:
+    target: models.modality_encoder.ModalityEncoder
+    params:
+      data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+      t_fix_pre: False
+      audio_dim: 256
+      audio_in: 2
+      raw_audio: False
+      latent_dim: 256
+      audio_fps: 30
+      use_exp: ${model.use_exp}
+
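The ${model.use_exp} entries are OmegaConf-style interpolations, so this file needs to be loaded with OmegaConf rather than plain PyYAML for them to resolve. A minimal sketch:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/sc_model_config.yaml")
# ${model.use_exp} resolves lazily against the root of the config tree,
# so both nested copies track the single top-level flag:
assert cfg.model.denoiser.params.use_exp == cfg.model.use_exp  # False here

This keeps the expression-modeling switch in one place; flipping model.use_exp updates the denoiser and the modality encoder together.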
configs/sc_model_holistic_config.yaml ADDED
@@ -0,0 +1,37 @@
+model:
+  model_name: LSM
+  g_name: GestureLSM
+  do_classifier_free_guidance: False
+  guidance_scale: 2
+  n_steps: 25
+  use_exp: True
+
+  denoiser:
+    target: models.denoiser.GestureDenoiser
+    params:
+      input_dim: 128
+      latent_dim: 256
+      ff_size: 1024
+      num_layers: 8
+      num_heads: 4
+      dropout: 0.1
+      activation: "gelu"
+      n_seed: 8
+      flip_sin_to_cos: True
+      freq_shift: 0.0
+      cond_proj_dim: 256
+      use_exp: ${model.use_exp}
+
+
+  modality_encoder:
+    target: models.modality_encoder.ModalityEncoder
+    params:
+      data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+      t_fix_pre: False
+      audio_dim: 256
+      audio_in: 2
+      raw_audio: False
+      latent_dim: 256
+      audio_fps: 30
+      use_exp: ${model.use_exp}
+
configs/sc_reflow_model_config.yaml ADDED
@@ -0,0 +1,37 @@
+model:
+  model_name: LSM
+  g_name: GestureLSM
+  do_classifier_free_guidance: False
+  guidance_scale: 2
+  n_steps: 2
+  use_exp: False
+
+  denoiser:
+    target: models.denoiser.GestureDenoiser
+    params:
+      input_dim: 128
+      latent_dim: 256
+      ff_size: 1024
+      num_layers: 8
+      num_heads: 4
+      dropout: 0.1
+      activation: "gelu"
+      n_seed: 8
+      flip_sin_to_cos: True
+      freq_shift: 0.0
+      cond_proj_dim: 256
+      use_exp: ${model.use_exp}
+
+
+  modality_encoder:
+    target: models.modality_encoder.ModalityEncoder
+    params:
+      data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+      t_fix_pre: False
+      audio_dim: 256
+      audio_in: 2
+      raw_audio: False
+      latent_dim: 256
+      audio_fps: 30
+      use_exp: ${model.use_exp}
+
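The three sc_*_config.yaml files differ only in n_steps and use_exp; n_steps: 2 for the reflow variant reflects few-step sampling in rectified-flow/shortcut-style models, where a learned velocity field is integrated with a coarse Euler scheme. A hedged sketch of that idea, assuming a velocity-prediction denoiser; the denoiser(x, t, cond) interface is illustrative, not the repo's actual signature:

import torch

@torch.no_grad()
def euler_sample(denoiser, cond, shape, n_steps=2, device="cuda"):
    """Few-step Euler integration of a learned velocity field v(x, t),
    as used by rectified-flow / shortcut-style samplers (sketch only)."""
    x = torch.randn(shape, device=device)      # start from noise at t = 0
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((shape[0],), i * dt, device=device)
        v = denoiser(x, t, cond)               # predicted velocity
        x = x + v * dt                         # Euler step toward data at t = 1
    return x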
configs/shortcut.yaml ADDED
@@ -0,0 +1,96 @@
+is_train: True
+ddp: False
+stat: ts
+root_path: ./
+out_path: ./outputs/audio2pose/
+project: s2g
+e_path: weights/AESKConv_240_100.bin
+eval_model: motion_representation
+e_name: VAESKConv
+data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+test_ckpt: ./ckpt/new_540_shortcut.bin
+data_path_1: ./datasets/hub/
+pose_norm: True
+cfg: configs/sc_model_config.yaml
+
+
+mean_pose_path: ./mean_std/beatx_2_330_mean.npy
+std_pose_path: ./mean_std/beatx_2_330_std.npy
+
+mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
+std_trans_path: ./mean_std/beatx_2_trans_std.npy
+
+
+vqvae_upper_path: ./ckpt/net_300000_upper.pth
+vqvae_hands_path: ./ckpt/net_300000_hands.pth
+vqvae_lower_path: ./ckpt/net_300000_lower.pth
+vqvae_face_path: ./ckpt/net_300000_face.pth
+vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
+use_trans: True
+
+decay_epoch: 500
+
+vqvae_squeeze_scale: 4
+vqvae_latent_scale: 5
+
+vae_test_len: 32
+vae_test_dim: 330
+vae_test_stride: 20
+vae_length: 240
+vae_codebook_size: 256
+vae_layer: 4
+vae_grow: [1,1,2,1]
+variational: False
+
+# data config
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
+additional_data: False
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
+dataset: beat_sep_lower
+new_cache: False
+
+# motion config
+ori_joints: beat_smplx_joints
+tar_joints: beat_smplx_full
+pose_rep: smplxflame_30
+pose_fps: 30
+rot6d: True
+pre_frames: 4
+pose_dims: 330
+pose_length: 128
+stride: 20
+test_length: 128
+m_fix_pre: False
+
+
+audio_rep: onset+amplitude
+audio_sr: 16000
+audio_fps: 16000
+audio_norm: False
+audio_f: 256
+audio_raw: None
+
+
+word_rep: textgrid
+word_dims: 300
+t_pre_encoder: fasttext
+
+
+facial_rep: smplxflame_30
+facial_dims: 100
+facial_norm: False
+facial_f: 0
+
+
+id_rep: onehot
+speaker_f: 0
+
+
+batch_size: 128
+lr_base: 2e-4
+trainer: shortcut_rvqvae
+
+rec_weight: 1
+grad_norm: 0.99
+epochs: 1000
+test_period: 20
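The vqvae_* keys tie this trainer to pretrained part-wise RVQ-VAE checkpoints (upper, hands, lower, face, plus a lower-body translation head). A small sketch of the window arithmetic these values imply; this is an inference from the config, not something the repository states explicitly:

# Window arithmetic implied by the config above (inferred, not authoritative).
pose_length = 128          # frames per training window
pose_fps = 30              # so each window covers ~4.27 s of motion
vqvae_squeeze_scale = 4    # assumed temporal downsampling of the RVQ-VAE encoder
latent_steps = pose_length // vqvae_squeeze_scale  # 32 latent frames per window
vqvae_latent_scale = 5     # presumably a magnitude rescaling of the latents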
configs/shortcut_hf.yaml ADDED
@@ -0,0 +1,96 @@
+is_train: True
+ddp: False
+stat: ts
+root_path: ./
+out_path: ./outputs/audio2pose/
+project: s2g
+e_path: weights/AESKConv_240_100.bin
+eval_model: motion_representation
+e_name: VAESKConv
+data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+test_ckpt: ./ckpt/new_540_shortcut.bin
+data_path_1: ./datasets/hub/
+pose_norm: True
+cfg: configs/sc_model_config.yaml
+
+
+mean_pose_path: ./mean_std/beatx_2_330_mean.npy
+std_pose_path: ./mean_std/beatx_2_330_std.npy
+
+mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
+std_trans_path: ./mean_std/beatx_2_trans_std.npy
+
+
+vqvae_upper_path: ./ckpt/net_300000_upper.pth
+vqvae_hands_path: ./ckpt/net_300000_hands.pth
+vqvae_lower_path: ./ckpt/net_300000_lower.pth
+
+vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
+use_trans: True
+
+decay_epoch: 500
+
+vqvae_squeeze_scale: 4
+vqvae_latent_scale: 5
+
+vae_test_len: 32
+vae_test_dim: 330
+vae_test_stride: 20
+vae_length: 240
+vae_codebook_size: 256
+vae_layer: 4
+vae_grow: [1,1,2,1]
+variational: False
+
+# data config
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
+additional_data: False
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
+dataset: beat_sep_single
+new_cache: False
+
+# motion config
+ori_joints: beat_smplx_joints
+tar_joints: beat_smplx_full
+pose_rep: smplxflame_30
+pose_fps: 30
+rot6d: True
+pre_frames: 4
+pose_dims: 330
+pose_length: 128
+stride: 20
+test_length: 128
+m_fix_pre: False
+
+
+audio_rep: onset+amplitude
+audio_sr: 16000
+audio_fps: 16000
+audio_norm: False
+audio_f: 256
+audio_raw: None
+
+
+word_rep: textgrid
+word_dims: 300
+t_pre_encoder: fasttext
+
+
+facial_rep: smplxflame_30
+facial_dims: 100
+facial_norm: False
+facial_f: 0
+
+
+id_rep: onehot
+speaker_f: 0
+
+
+batch_size: 128
+lr_base: 2e-4
+trainer: shortcut_rvqvae
+
+rec_weight: 1
+grad_norm: 0.99
+epochs: 1000
+test_period: 20
configs/shortcut_holistic.yaml ADDED
@@ -0,0 +1,96 @@
+is_train: True
+ddp: False
+stat: ts
+root_path: ./
+out_path: ./outputs/audio2pose/
+project: s2g
+e_path: weights/AESKConv_240_100.bin
+eval_model: motion_representation
+e_name: VAESKConv
+data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+test_ckpt: ./ckpt/new_540_shortcut_holistic.bin
+data_path_1: ./datasets/hub/
+pose_norm: True
+cfg: configs/sc_model_holistic_config.yaml
+
+
+mean_pose_path: ./mean_std/beatx_2_330_mean.npy
+std_pose_path: ./mean_std/beatx_2_330_std.npy
+
+mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
+std_trans_path: ./mean_std/beatx_2_trans_std.npy
+
+
+vqvae_upper_path: ./ckpt/net_300000_upper.pth
+vqvae_hands_path: ./ckpt/net_300000_hands.pth
+vqvae_lower_path: ./ckpt/net_300000_lower.pth
+vqvae_face_path: ./ckpt/net_300000_face.pth
+vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
+use_trans: True
+
+decay_epoch: 500
+
+vqvae_squeeze_scale: 4
+vqvae_latent_scale: 5
+
+vae_test_len: 32
+vae_test_dim: 330
+vae_test_stride: 20
+vae_length: 240
+vae_codebook_size: 256
+vae_layer: 4
+vae_grow: [1,1,2,1]
+variational: False
+
+# data config
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
+additional_data: False
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
+dataset: beat_sep_lower
+new_cache: False
+
+# motion config
+ori_joints: beat_smplx_joints
+tar_joints: beat_smplx_full
+pose_rep: smplxflame_30
+pose_fps: 30
+rot6d: True
+pre_frames: 4
+pose_dims: 330
+pose_length: 128
+stride: 20
+test_length: 128
+m_fix_pre: False
+
+
+audio_rep: onset+amplitude
+audio_sr: 16000
+audio_fps: 16000
+audio_norm: False
+audio_f: 256
+audio_raw: None
+
+
+word_rep: textgrid
+word_dims: 300
+t_pre_encoder: fasttext
+
+
+facial_rep: smplxflame_30
+facial_dims: 100
+facial_norm: False
+facial_f: 0
+
+
+id_rep: onehot
+speaker_f: 0
+
+
+batch_size: 128
+lr_base: 2e-4
+trainer: shortcut_rvqvae
+
+rec_weight: 1
+grad_norm: 0.99
+epochs: 1000
+test_period: 20
configs/shortcut_reflow.yaml ADDED
@@ -0,0 +1,96 @@
+is_train: True
+ddp: False
+stat: ts
+root_path: ./
+out_path: ./outputs/audio2pose/
+project: s2g
+e_path: weights/AESKConv_240_100.bin
+eval_model: motion_representation
+e_name: VAESKConv
+data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+test_ckpt: ./outputs/audio2pose/custom/0212_125039_shortcut_reflow/last_20.bin
+data_path_1: ./datasets/hub/
+pose_norm: True
+cfg: configs/sc_model_config.yaml
+
+
+mean_pose_path: ./mean_std/beatx_2_330_mean.npy
+std_pose_path: ./mean_std/beatx_2_330_std.npy
+
+mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
+std_trans_path: ./mean_std/beatx_2_trans_std.npy
+
+
+vqvae_upper_path: ./ckpt/net_300000_upper.pth
+vqvae_hands_path: ./ckpt/net_300000_hands.pth
+vqvae_lower_path: ./ckpt/net_300000_lower.pth
+vqvae_face_path: ./ckpt/net_300000_face.pth
+vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
+use_trans: True
+
+decay_epoch: 500
+
+vqvae_squeeze_scale: 4
+vqvae_latent_scale: 5
+
+vae_test_len: 32
+vae_test_dim: 330
+vae_test_stride: 20
+vae_length: 240
+vae_codebook_size: 256
+vae_layer: 4
+vae_grow: [1,1,2,1]
+variational: False
+
+# data config
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
+additional_data: False
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
+dataset: beat_sep_reflow
+new_cache: False
+
+# motion config
+ori_joints: beat_smplx_joints
+tar_joints: beat_smplx_full
+pose_rep: smplxflame_30
+pose_fps: 30
+rot6d: True
+pre_frames: 4
+pose_dims: 330
+pose_length: 128
+stride: 20
+test_length: 128
+m_fix_pre: False
+
+
+audio_rep: onset+amplitude
+audio_sr: 16000
+audio_fps: 16000
+audio_norm: False
+audio_f: 256
+audio_raw: None
+
+
+word_rep: textgrid
+word_dims: 300
+t_pre_encoder: fasttext
+
+
+facial_rep: smplxflame_30
+facial_dims: 100
+facial_norm: False
+facial_f: 0
+
+
+id_rep: onehot
+speaker_f: 0
+
+
+batch_size: 1
+lr_base: 2e-4
+trainer: shortcut_rvqvae
+
+rec_weight: 1
+grad_norm: 0.99
+epochs: 1000
+test_period: 20
configs/shortcut_reflow_test.yaml ADDED
@@ -0,0 +1,96 @@
+is_train: True
+ddp: False
+stat: ts
+root_path: ./
+out_path: ./outputs/audio2pose/
+project: s2g
+e_path: weights/AESKConv_240_100.bin
+eval_model: motion_representation
+e_name: VAESKConv
+data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+test_ckpt: ./ckpt/shortcut_reflow.bin
+data_path_1: ./datasets/hub/
+pose_norm: True
+cfg: configs/sc_reflow_model_config.yaml
+
+
+mean_pose_path: ./mean_std/beatx_2_330_mean.npy
+std_pose_path: ./mean_std/beatx_2_330_std.npy
+
+mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
+std_trans_path: ./mean_std/beatx_2_trans_std.npy
+
+
+vqvae_upper_path: ./ckpt/net_300000_upper.pth
+vqvae_hands_path: ./ckpt/net_300000_hands.pth
+vqvae_lower_path: ./ckpt/net_300000_lower.pth
+vqvae_face_path: ./ckpt/net_300000_face.pth
+vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
+use_trans: True
+
+decay_epoch: 500
+
+vqvae_squeeze_scale: 4
+vqvae_latent_scale: 5
+
+vae_test_len: 32
+vae_test_dim: 330
+vae_test_stride: 20
+vae_length: 240
+vae_codebook_size: 256
+vae_layer: 4
+vae_grow: [1,1,2,1]
+variational: False
+
+# data config
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
+additional_data: False
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
+dataset: beat_sep_lower
+new_cache: False
+
+# motion config
+ori_joints: beat_smplx_joints
+tar_joints: beat_smplx_full
+pose_rep: smplxflame_30
+pose_fps: 30
+rot6d: True
+pre_frames: 4
+pose_dims: 330
+pose_length: 128
+stride: 20
+test_length: 128
+m_fix_pre: False
+
+
+audio_rep: onset+amplitude
+audio_sr: 16000
+audio_fps: 16000
+audio_norm: False
+audio_f: 256
+audio_raw: None
+
+
+word_rep: textgrid
+word_dims: 300
+t_pre_encoder: fasttext
+
+
+facial_rep: smplxflame_30
+facial_dims: 100
+facial_norm: False
+facial_f: 0
+
+
+id_rep: onehot
+speaker_f: 0
+
+
+batch_size: 1
+lr_base: 2e-4
+trainer: shortcut_rvqvae
+
+rec_weight: 1
+grad_norm: 0.99
+epochs: 1000
+test_period: 20
configs/shortcut_rvqvae_128.yaml ADDED
@@ -0,0 +1,96 @@
+is_train: True
+ddp: False
+stat: ts
+root_path: ./
+out_path: ./outputs/audio2pose/
+project: s2g
+e_path: weights/AESKConv_240_100.bin
+eval_model: motion_representation
+e_name: VAESKConv
+data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+test_ckpt: ./ckpt/new_540_shortcut.bin
+data_path_1: ./datasets/hub/
+pose_norm: True
+cfg: configs/sc_model_config.yaml
+
+
+mean_pose_path: ./mean_std/beatx_2_330_mean.npy
+std_pose_path: ./mean_std/beatx_2_330_std.npy
+
+mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
+std_trans_path: ./mean_std/beatx_2_trans_std.npy
+
+
+vqvae_upper_path: ./ckpt/net_300000_upper.pth
+vqvae_hands_path: ./ckpt/net_300000_hands.pth
+vqvae_lower_path: ./ckpt/net_300000_lower.pth
+
+vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
+use_trans: True
+
+decay_epoch: 500
+
+vqvae_squeeze_scale: 4
+vqvae_latent_scale: 5
+
+vae_test_len: 32
+vae_test_dim: 330
+vae_test_stride: 20
+vae_length: 240
+vae_codebook_size: 256
+vae_layer: 4
+vae_grow: [1,1,2,1]
+variational: False
+
+# data config
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
+additional_data: False
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
+dataset: beat_sep_lower
+new_cache: False
+
+# motion config
+ori_joints: beat_smplx_joints
+tar_joints: beat_smplx_full
+pose_rep: smplxflame_30
+pose_fps: 30
+rot6d: True
+pre_frames: 4
+pose_dims: 330
+pose_length: 128
+stride: 20
+test_length: 128
+m_fix_pre: False
+
+
+audio_rep: onset+amplitude
+audio_sr: 16000
+audio_fps: 16000
+audio_norm: False
+audio_f: 256
+audio_raw: None
+
+
+word_rep: textgrid
+word_dims: 300
+t_pre_encoder: fasttext
+
+
+facial_rep: smplxflame_30
+facial_dims: 100
+facial_norm: False
+facial_f: 0
+
+
+id_rep: onehot
+speaker_f: 0
+
+
+batch_size: 128
+lr_base: 2e-4
+trainer: shortcut_rvqvae
+
+rec_weight: 1
+grad_norm: 0.99
+epochs: 1000
+test_period: 20
configs/shortcut_rvqvae_128_hf.yaml ADDED
@@ -0,0 +1,96 @@
+is_train: True
+ddp: False
+stat: ts
+root_path: ./
+out_path: ./outputs/audio2pose/
+project: s2g
+e_path: weights/AESKConv_240_100.bin
+eval_model: motion_representation
+e_name: VAESKConv
+data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+test_ckpt: ./ckpt/new_540_shortcut.bin
+data_path_1: ./datasets/hub/
+pose_norm: True
+cfg: configs/sc_model_config.yaml
+
+
+mean_pose_path: ./mean_std/beatx_2_330_mean.npy
+std_pose_path: ./mean_std/beatx_2_330_std.npy
+
+mean_trans_path: ./mean_std/beatx_2_trans_mean.npy
+std_trans_path: ./mean_std/beatx_2_trans_std.npy
+
+
+vqvae_upper_path: ./ckpt/net_300000_upper.pth
+vqvae_hands_path: ./ckpt/net_300000_hands.pth
+vqvae_lower_path: ./ckpt/net_300000_lower.pth
+
+vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth
+use_trans: True
+
+decay_epoch: 500
+
+vqvae_squeeze_scale: 4
+vqvae_latent_scale: 5
+
+vae_test_len: 32
+vae_test_dim: 330
+vae_test_stride: 20
+vae_length: 240
+vae_codebook_size: 256
+vae_layer: 4
+vae_grow: [1,1,2,1]
+variational: False
+
+# data config
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
+additional_data: False
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/
+dataset: beat_sep_single
+new_cache: False
+
+# motion config
+ori_joints: beat_smplx_joints
+tar_joints: beat_smplx_full
+pose_rep: smplxflame_30
+pose_fps: 30
+rot6d: True
+pre_frames: 4
+pose_dims: 330
+pose_length: 128
+stride: 20
+test_length: 128
+m_fix_pre: False
+
+
+audio_rep: onset+amplitude
+audio_sr: 16000
+audio_fps: 16000
+audio_norm: False
+audio_f: 256
+audio_raw: None
+
+
+word_rep: textgrid
+word_dims: 300
+t_pre_encoder: fasttext
+
+
+facial_rep: smplxflame_30
+facial_dims: 100
+facial_norm: False
+facial_f: 0
+
+
+id_rep: onehot
+speaker_f: 0
+
+
+batch_size: 128
+lr_base: 2e-4
+trainer: shortcut_rvqvae
+
+rec_weight: 1
+grad_norm: 0.99
+epochs: 1000
+test_period: 20
dataloaders/__pycache__/beat_sep_single.cpython-312.pyc ADDED
Binary file (42 kB)

dataloaders/__pycache__/build_vocab.cpython-312.pyc ADDED
Binary file (10.3 kB)

dataloaders/__pycache__/data_tools.cpython-312.pyc ADDED
Binary file (43.4 kB)
dataloaders/beat_dataset_new.py ADDED
@@ -0,0 +1,373 @@
+import os
+import pickle
+import math
+import shutil
+import numpy as np
+import lmdb as lmdb
+import textgrid as tg
+import pandas as pd
+import torch
+import glob
+import json
+from termcolor import colored
+from loguru import logger
+from collections import defaultdict
+from torch.utils.data import Dataset
+import torch.distributed as dist
+import pickle
+import smplx
+from .utils.audio_features import AudioProcessor
+from .utils.other_tools import MultiLMDBManager
+from .utils.motion_rep_transfer import process_smplx_motion
+from .utils.mis_features import process_semantic_data, process_emotion_data
+from .utils.text_features import process_word_data
+from .utils.data_sample import sample_from_clip
+from .utils import rotation_conversions as rc
+
+class CustomDataset(Dataset):
+    def __init__(self, args, loader_type, build_cache=True):
+        self.args = args
+        self.loader_type = loader_type
+        self.rank = dist.get_rank()
+
+        self.ori_stride = self.args.stride
+        self.ori_length = self.args.pose_length
+
+        # Initialize basic parameters
+        self.ori_stride = self.args.stride
+        self.ori_length = self.args.pose_length
+        self.alignment = [0,0] # for trinity
+
+        # Initialize SMPLX model
+        self.smplx = smplx.create(
+            self.args.data_path_1+"smplx_models/",
+            model_type='smplx',
+            gender='NEUTRAL_2020',
+            use_face_contour=False,
+            num_betas=300,
+            num_expression_coeffs=100,
+            ext='npz',
+            use_pca=False,
+        ).cuda().eval()
+
+        self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
+
+        # Load and process split rules
+        self._process_split_rules()
+
+        # Initialize data directories and lengths
+        self._init_data_paths()
+
+        # Build or load cache
+        self._init_cache(build_cache)
+
+
+    def _process_split_rules(self):
+        """Process dataset split rules."""
+        split_rule = pd.read_csv(self.args.data_path+"train_test_split.csv")
+        self.selected_file = split_rule.loc[
+            (split_rule['type'] == self.loader_type) &
+            (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
+        ]
+
+        if self.args.additional_data and self.loader_type == 'train':
+            split_b = split_rule.loc[
+                (split_rule['type'] == 'additional') &
+                (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
+            ]
+            self.selected_file = pd.concat([self.selected_file, split_b])
+
+        if self.selected_file.empty:
+            logger.warning(f"{self.loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead")
+            self.selected_file = split_rule.loc[
+                (split_rule['type'] == 'train') &
+                (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
+            ]
+            self.selected_file = self.selected_file.iloc[0:8]
+
+    def _init_data_paths(self):
+        """Initialize data directories and lengths."""
+        self.data_dir = self.args.data_path
+
+        if self.loader_type == "test":
+            self.args.multi_length_training = [1.0]
+
+        self.max_length = int(self.args.pose_length * self.args.multi_length_training[-1])
+        self.max_audio_pre_len = math.floor(self.args.pose_length / self.args.pose_fps * self.args.audio_sr)
+
+        if self.max_audio_pre_len > self.args.test_length * self.args.audio_sr:
+            self.max_audio_pre_len = self.args.test_length * self.args.audio_sr
+
+        if self.args.test_clip and self.loader_type == "test":
+            self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + "_clip" + f"/{self.args.pose_rep}_cache"
+        else:
+            self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + f"/{self.args.pose_rep}_cache"
+
+
+    def _init_cache(self, build_cache):
+        """Initialize or build cache."""
+        self.lmdb_envs = {}
+        self.mapping_data = None
+
+        if build_cache and self.rank == 0:
+            self.build_cache(self.preloaded_dir)
+
+        self.load_db_mapping()
+
+    def build_cache(self, preloaded_dir):
+        """Build the dataset cache."""
+        logger.info(f"Audio bit rate: {self.args.audio_fps}")
+        logger.info("Reading data '{}'...".format(self.data_dir))
+        logger.info("Creating the dataset cache...")
+
+        if self.args.new_cache and os.path.exists(preloaded_dir):
+            shutil.rmtree(preloaded_dir)
+
+        if os.path.exists(preloaded_dir):
+            # if the dir is empty, that means we still need to build the cache
+            if not os.listdir(preloaded_dir):
+                self.cache_generation(
+                    preloaded_dir,
+                    self.args.disable_filtering,
+                    self.args.clean_first_seconds,
+                    self.args.clean_final_seconds,
+                    is_test=False
+                )
+            else:
+                logger.info("Found the cache {}".format(preloaded_dir))
+
+        elif self.loader_type == "test":
+            self.cache_generation(preloaded_dir, True, 0, 0, is_test=True)
+        else:
+            self.cache_generation(
+                preloaded_dir,
+                self.args.disable_filtering,
+                self.args.clean_first_seconds,
+                self.args.clean_final_seconds,
+                is_test=False
+            )
+
+    def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
+        """Generate cache for the dataset."""
+        if not os.path.exists(out_lmdb_dir):
+            os.makedirs(out_lmdb_dir)
+
+        self.audio_processor = AudioProcessor(layer=self.args.n_layer, use_distill=self.args.use_distill)
+
+        # Initialize the multi-LMDB manager
+        lmdb_manager = MultiLMDBManager(out_lmdb_dir, max_db_size=10*1024*1024*1024)
+
+        self.n_out_samples = 0
+        n_filtered_out = defaultdict(int)
+
+        for index, file_name in self.selected_file.iterrows():
+            f_name = file_name["id"]
+            ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
+            pose_file = os.path.join(self.data_dir, self.args.pose_rep, f_name + ext)
+
+            # Process data
+            data = self._process_file_data(f_name, pose_file, ext)
+            if data is None:
+                continue
+
+            # Sample from clip
+            filtered_result, self.n_out_samples = sample_from_clip(
+                lmdb_manager=lmdb_manager,
+                audio_file=pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav"),
+                audio_each_file=data['audio_tensor'],
+                high_each_file=data['high_level'],
+                low_each_file=data['low_level'],
+                pose_each_file=data['pose'],
+                rep15d_each_file=data['rep15d'],
+                trans_each_file=data['trans'],
+                trans_v_each_file=data['trans_v'],
+                shape_each_file=data['shape'],
+                facial_each_file=data['facial'],
+                aligned_text_each_file=data['aligned_text'],
+                word_each_file=data['word'] if self.args.word_rep is not None else None,
+                vid_each_file=data['vid'],
+                emo_each_file=data['emo'],
+                sem_each_file=data['sem'],
+                intention_each_file=data['intention'] if data['intention'] is not None else None,
+                audio_onset_each_file=data['audio_onset'] if self.args.onset_rep else None,
+                args=self.args,
+                ori_stride=self.ori_stride,
+                ori_length=self.ori_length,
+                disable_filtering=disable_filtering,
+                clean_first_seconds=clean_first_seconds,
+                clean_final_seconds=clean_final_seconds,
+                is_test=is_test,
+                n_out_samples=self.n_out_samples
+            )
+
+            for type_key in filtered_result:
+                n_filtered_out[type_key] += filtered_result[type_key]
+
+        lmdb_manager.close()
+
+    def _process_file_data(self, f_name, pose_file, ext):
+        """Process all data for a single file."""
+        data = {
+            'pose': None, 'trans': None, 'trans_v': None, 'shape': None,
+            'audio': None, 'facial': None, 'word': None, 'emo': None,
+            'sem': None, 'vid': None
+        }
+
+        # Process motion data
+        logger.info(colored(f"# ---- Building cache for Pose {f_name} ---- #", "blue"))
+        if "smplx" in self.args.pose_rep:
+            motion_data = process_smplx_motion(pose_file, self.smplx, self.args.pose_fps, self.args.facial_rep)
+        else:
+            raise ValueError(f"Unknown pose representation '{self.args.pose_rep}'.")
+
+        if motion_data is None:
+            return None
+
+        data.update(motion_data)
+
+        # Process speaker ID
+        if self.args.id_rep is not None:
+            speaker_id = int(f_name.split("_")[0]) - 1
+            data['vid'] = np.repeat(np.array(speaker_id).reshape(1, 1), data['pose'].shape[0], axis=0)
+        else:
+            data['vid'] = np.array([-1])
+
+        # Process audio if needed
+        if self.args.audio_rep is not None:
+            audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
+            audio_data = self.audio_processor.get_wav2vec_from_16k_wav(audio_file, aligned_text=True)
+            if audio_data is None:
+                return None
+            data.update(audio_data)
+
+        if getattr(self.args, "onset_rep", False):
+            audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
+            onset_data = self.audio_processor.calculate_onset_amplitude(audio_file, data)
+            if onset_data is None:
+                return None
+            data.update(onset_data)
+
+        # Process emotion if needed
+        if self.args.emo_rep is not None:
+            data = process_emotion_data(f_name, data, self.args)
+            if data is None:
+                return None
+
+        # Process word data if needed
+        if self.args.word_rep is not None:
+            word_file = f"{self.data_dir}{self.args.word_rep}/{f_name}.TextGrid"
+            data = process_word_data(self.data_dir, word_file, self.args, data, f_name, self.selected_file)
+            if data is None:
+                return None
+
+
+        # Process semantic data if needed
+        if self.args.sem_rep is not None:
+            sem_file = f"{self.data_dir}{self.args.sem_rep}/{f_name}.txt"
+            data = process_semantic_data(sem_file, self.args, data, f_name)
+            if data is None:
+                return None
+
+        return data
+
+    def load_db_mapping(self):
+        """Load database mapping from file."""
+        mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl")
+        with open(mapping_path, 'rb') as f:
+            self.mapping_data = pickle.load(f)
+
+
+        # Update paths from test to test_clip if needed
+        if self.loader_type == "test" and self.args.test_clip:
+            updated_paths = []
+            for path in self.mapping_data['db_paths']:
+                updated_path = path.replace("test/", "test_clip/")
+                updated_paths.append(updated_path)
+            self.mapping_data['db_paths'] = updated_paths
+
+            # Re-save the updated mapping_data to the same pickle file
+            with open(mapping_path, 'wb') as f:
+                pickle.dump(self.mapping_data, f)
+
+        self.n_samples = len(self.mapping_data['mapping'])
+
+    def get_lmdb_env(self, db_idx):
+        """Get LMDB environment for given database index."""
+        if db_idx not in self.lmdb_envs:
+            db_path = self.mapping_data['db_paths'][db_idx]
+            self.lmdb_envs[db_idx] = lmdb.open(db_path, readonly=True, lock=False)
+        return self.lmdb_envs[db_idx]
+
+    def __len__(self):
+        """Return the total number of samples in the dataset."""
+        return self.n_samples
+
+    def __getitem__(self, idx):
+        """Get a single sample from the dataset."""
+        db_idx = self.mapping_data['mapping'][idx]
+        lmdb_env = self.get_lmdb_env(db_idx)
+
+        with lmdb_env.begin(write=False) as txn:
+            key = "{:008d}".format(idx).encode("ascii")
+            sample = txn.get(key)
+            sample = pickle.loads(sample)
+
+
+        tar_pose, in_audio, in_audio_high, in_audio_low, tar_rep15d, in_facial, in_shape, in_aligned_text, in_word, emo, sem, vid, trans, trans_v, intention, audio_name, audio_onset = sample
+
+
+        # Convert data to tensors with appropriate types
+        processed_data = self._convert_to_tensors(
+            tar_pose, tar_rep15d, in_audio, in_audio_high, in_audio_low, in_facial, in_shape, in_aligned_text, in_word,
+            emo, sem, vid, trans, trans_v, intention, audio_onset
+        )
+
+        processed_data['audio_name'] = audio_name
+        return processed_data
+
+    def _convert_to_tensors(self, tar_pose, tar_rep15d, in_audio, in_audio_high, in_audio_low, in_facial, in_shape, in_aligned_text, in_word,
+                            emo, sem, vid, trans, trans_v, intention=None, audio_onset=None):
+        """Convert numpy arrays to tensors with appropriate types."""
+        data = {
+            'emo': torch.from_numpy(emo).int(),
+            'sem': torch.from_numpy(sem).float(),
+            'audio_tensor': torch.from_numpy(in_audio).float(),
+            'bert_time_aligned': torch.from_numpy(in_aligned_text).float()
+        }
+        tar_pose = torch.from_numpy(tar_pose).float()
+
+        if self.loader_type == "test":
+
data.update({
|
| 342 |
+
'pose': tar_pose,
|
| 343 |
+
'rep15d': torch.from_numpy(tar_rep15d).float(),
|
| 344 |
+
'trans': torch.from_numpy(trans).float(),
|
| 345 |
+
'trans_v': torch.from_numpy(trans_v).float(),
|
| 346 |
+
'facial': torch.from_numpy(in_facial).float(),
|
| 347 |
+
'id': torch.from_numpy(vid).float(),
|
| 348 |
+
'beta': torch.from_numpy(in_shape).float()
|
| 349 |
+
})
|
| 350 |
+
else:
|
| 351 |
+
data.update({
|
| 352 |
+
'pose': tar_pose,
|
| 353 |
+
'rep15d': torch.from_numpy(tar_rep15d).reshape((tar_rep15d.shape[0], -1)).float(),
|
| 354 |
+
'trans': torch.from_numpy(trans).reshape((trans.shape[0], -1)).float(),
|
| 355 |
+
'trans_v': torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float(),
|
| 356 |
+
'facial': torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float(),
|
| 357 |
+
'id': torch.from_numpy(vid).reshape((vid.shape[0], -1)).float(),
|
| 358 |
+
'beta': torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
|
| 359 |
+
})
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# Handle audio onset
|
| 363 |
+
if audio_onset is not None:
|
| 364 |
+
data['audio_onset'] = torch.from_numpy(audio_onset).float()
|
| 365 |
+
else:
|
| 366 |
+
data['audio_onset'] = torch.tensor([-1])
|
| 367 |
+
|
| 368 |
+
if in_word is not None:
|
| 369 |
+
data['word'] = torch.from_numpy(in_word).int()
|
| 370 |
+
else:
|
| 371 |
+
data['word'] = torch.tensor([-1])
|
| 372 |
+
|
| 373 |
+
return data
|
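Since the class above is a regular torch.utils.data.Dataset, a minimal consumption sketch looks as follows. This is an editorial example, not part of the upload: `args` is assumed to be the parsed config object whose fields the class reads, and the batch_size/num_workers values are purely illustrative.

# Editorial sketch (not part of the upload): consuming the Dataset above.
from torch.utils.data import DataLoader

def build_loader(args, loader_type="train"):
    dataset = CustomDataset(args, loader_type, build_cache=True)
    return DataLoader(
        dataset,
        batch_size=32,                        # illustrative value
        shuffle=(loader_type == "train"),
        num_workers=4,
        drop_last=(loader_type == "train"),
    )

# Each batch is a dict keyed as in _convert_to_tensors above: 'pose', 'rep15d',
# 'audio_tensor', 'bert_time_aligned', 'trans', 'trans_v', 'facial', 'id',
# 'beta', 'audio_onset', 'word', plus the raw string 'audio_name'.

Because get_lmdb_env opens environments lazily on first access, each DataLoader worker ends up with its own read-only LMDB handles rather than sharing one handle across processes.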
dataloaders/beat_sep.py
ADDED
@@ -0,0 +1,772 @@
import os
import pickle
import math
import shutil
import numpy as np
import lmdb
import textgrid as tg
import pandas as pd
import torch
import glob
import json
from termcolor import colored
from loguru import logger
from collections import defaultdict
from torch.utils.data import Dataset
import torch.distributed as dist
#import pyarrow
import librosa
import smplx

from .build_vocab import Vocab
from .utils.audio_features import Wav2Vec2Model
from .data_tools import joints_list
from .utils import rotation_conversions as rc
from .utils import other_tools

class CustomDataset(Dataset):
    def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
        self.args = args
        self.loader_type = loader_type

        self.rank = dist.get_rank()
        self.ori_stride = self.args.stride
        self.ori_length = self.args.pose_length
        self.alignment = [0, 0]  # for trinity

        self.ori_joint_list = joints_list[self.args.ori_joints]
        self.tar_joint_list = joints_list[self.args.tar_joints]
        if 'smplx' in self.args.pose_rep:
            self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3)
            self.joints = len(list(self.tar_joint_list.keys()))
            for joint_name in self.tar_joint_list:
                self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
        else:
            self.joints = len(list(self.ori_joint_list.keys()))+1
            self.joint_mask = np.zeros(self.joints*3)
            for joint_name in self.tar_joint_list:
                if joint_name == "Hips":
                    self.joint_mask[3:6] = 1
                else:
                    self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
        # select trainable joints

        split_rule = pd.read_csv(args.data_path+"train_test_split.csv")
        self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
        if args.additional_data and loader_type == 'train':
            split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
            self.selected_file = pd.concat([self.selected_file, split_b])
        if self.selected_file.empty:
            logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead")
            self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
            self.selected_file = self.selected_file.iloc[0:8]
        self.data_dir = args.data_path

        if loader_type == "test":
            self.args.multi_length_training = [1.0]
        self.max_length = int(args.pose_length * self.args.multi_length_training[-1])
        self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr)
        if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr:
            self.max_audio_pre_len = self.args.test_length*self.args.audio_sr

        if args.word_rep is not None:
            with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f:
                self.lang_model = pickle.load(f)

        preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache"
        # if args.pose_norm:
        #     # careful for rotation vectors
        #     if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"):
        #         self.calculate_mean_pose()
        #     self.mean_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy")
        #     self.std_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_std.npy")
        # if args.audio_norm:
        #     if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/bvh_mean.npy"):
        #         self.calculate_mean_audio()
        #     self.mean_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_mean.npy")
        #     self.std_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_std.npy")
        # if args.facial_norm:
        #     if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"):
        #         self.calculate_mean_face()
        #     self.mean_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy")
        #     self.std_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy")
        if self.args.beat_align:
            if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"):
                self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
            self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")

        if build_cache and self.rank == 0:
            self.build_cache(preloaded_dir)
        self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False)
        with self.lmdb_env.begin() as txn:
            self.n_samples = txn.stat()["entries"]

    def calculate_mean_velocity(self, save_path):
        self.smplx = smplx.create(
            self.args.data_path_1+"smplx_models/",
            model_type='smplx',
            gender='NEUTRAL_2020',
            use_face_contour=False,
            num_betas=300,
            num_expression_coeffs=100,
            ext='npz',
            use_pca=False,
        ).cuda().eval()
        dir_p = self.data_dir + self.args.pose_rep + "/"
        all_list = []
        from tqdm import tqdm
        for tar in tqdm(os.listdir(dir_p)):
            if tar.endswith(".npz"):
                m_data = np.load(dir_p+tar, allow_pickle=True)
                betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"]
                n, c = poses.shape[0], poses.shape[1]
                betas = betas.reshape(1, 300)
                betas = np.tile(betas, (n, 1))
                betas = torch.from_numpy(betas).cuda().float()
                poses = torch.from_numpy(poses.reshape(n, c)).cuda().float()
                exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float()
                trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float()
                max_length = 128
                s, r = n//max_length, n%max_length
                # print(n, s, r)
                all_tensor = []
                for i in range(s):
                    with torch.no_grad():
                        joints = self.smplx(
                            betas=betas[i*max_length:(i+1)*max_length],
                            transl=trans[i*max_length:(i+1)*max_length],
                            expression=exps[i*max_length:(i+1)*max_length],
                            jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69],
                            global_orient=poses[i*max_length:(i+1)*max_length, :3],
                            body_pose=poses[i*max_length:(i+1)*max_length, 3:21*3+3],
                            left_hand_pose=poses[i*max_length:(i+1)*max_length, 25*3:40*3],
                            right_hand_pose=poses[i*max_length:(i+1)*max_length, 40*3:55*3],
                            return_verts=True,
                            return_joints=True,
                            leye_pose=poses[i*max_length:(i+1)*max_length, 69:72],
                            reye_pose=poses[i*max_length:(i+1)*max_length, 72:75],
                        )['joints'][:, :55, :].reshape(max_length, 55*3)
                    all_tensor.append(joints)
                if r != 0:
                    with torch.no_grad():
                        joints = self.smplx(
                            betas=betas[s*max_length:s*max_length+r],
                            transl=trans[s*max_length:s*max_length+r],
                            expression=exps[s*max_length:s*max_length+r],
                            jaw_pose=poses[s*max_length:s*max_length+r, 66:69],
                            global_orient=poses[s*max_length:s*max_length+r, :3],
                            body_pose=poses[s*max_length:s*max_length+r, 3:21*3+3],
                            left_hand_pose=poses[s*max_length:s*max_length+r, 25*3:40*3],
                            right_hand_pose=poses[s*max_length:s*max_length+r, 40*3:55*3],
                            return_verts=True,
                            return_joints=True,
                            leye_pose=poses[s*max_length:s*max_length+r, 69:72],
                            reye_pose=poses[s*max_length:s*max_length+r, 72:75],
                        )['joints'][:, :55, :].reshape(r, 55*3)
                    all_tensor.append(joints)
                joints = torch.cat(all_tensor, axis=0)
                joints = joints.permute(1, 0)
                dt = 1/30
                # first step is forward diff (t+1 - t) / dt
                init_vel = (joints[:, 1:2] - joints[:, :1]) / dt
                # middle steps are second order (t+1 - t-1) / 2dt
                middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt)
                # last step is backward diff (t - t-1) / dt
                final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
                # print(joints.shape, init_vel.shape, middle_vel.shape, final_vel.shape)
                vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1).permute(1, 0).reshape(n, 55, 3)
                # print(vel_seq.shape)
                vel_seq_np = vel_seq.cpu().numpy()
                vel_joints_np = np.linalg.norm(vel_seq_np, axis=2)  # n * 55
                all_list.append(vel_joints_np)
        avg_vel = np.mean(np.concatenate(all_list, axis=0), axis=0)  # 55
        np.save(save_path, avg_vel)

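    # --- Editorial aside, not part of the uploaded file ------------------------
    # The velocity computation above is the standard first/second-order finite
    # difference scheme. A minimal self-contained check (numpy only):
    #
    #     import numpy as np
    #     dt = 1 / 30
    #     t = np.arange(5) * dt
    #     x = t ** 2                                 # toy trajectory, dx/dt = 2t
    #     init_vel = (x[1] - x[0]) / dt              # forward difference
    #     middle_vel = (x[2:] - x[:-2]) / (2 * dt)   # central difference
    #     final_vel = (x[-1] - x[-2]) / dt           # backward difference
    #     vel = np.concatenate([[init_vel], middle_vel, [final_vel]])
    #
    # For a quadratic the central-difference entries equal 2t exactly; only the
    # two endpoint estimates are first-order accurate.
    # ----------------------------------------------------------------------------
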
    def build_cache(self, preloaded_dir):
        logger.info(f"Audio bit rate: {self.args.audio_fps}")
        logger.info("Reading data '{}'...".format(self.data_dir))
        logger.info("Creating the dataset cache...")
        if self.args.new_cache:
            if os.path.exists(preloaded_dir):
                shutil.rmtree(preloaded_dir)
        if os.path.exists(preloaded_dir):
            logger.info("Found the cache {}".format(preloaded_dir))
        elif self.loader_type == "test":
            self.cache_generation(
                preloaded_dir, True,
                0, 0,
                is_test=True)
        else:
            self.cache_generation(
                preloaded_dir, self.args.disable_filtering,
                self.args.clean_first_seconds, self.args.clean_final_seconds,
                is_test=False)

    def __len__(self):
        return self.n_samples

    def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
        # if "wav2vec2" in self.args.audio_rep:
        #     self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h")
        #     self.wav2vec_model.feature_extractor._freeze_parameters()
        #     self.wav2vec_model = self.wav2vec_model.cuda()
        #     self.wav2vec_model.eval()

        self.n_out_samples = 0
        # create db for samples
        if not os.path.exists(out_lmdb_dir):
            os.makedirs(out_lmdb_dir)
        dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size=int(1024 ** 3 * 50))  # 50G
        n_filtered_out = defaultdict(int)

        for index, file_name in self.selected_file.iterrows():
            f_name = file_name["id"]
            ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
            pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext
            pose_each_file = []
            trans_each_file = []
            shape_each_file = []
            audio_each_file = []
            facial_each_file = []
            word_each_file = []
            emo_each_file = []
            sem_each_file = []
            vid_each_file = []
            id_pose = f_name  # e.g. 1_wayne_0_1_1

            logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue"))
            if "smplx" in self.args.pose_rep:
                pose_data = np.load(pose_file, allow_pickle=True)
                assert 30 % self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30'
                stride = int(30/self.args.pose_fps)
                pose_each_file = pose_data["poses"][::stride] * self.joint_mask
                pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)]
                # print(pose_each_file.shape)
                trans_each_file = pose_data["trans"][::stride]
                shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0)
                if self.args.facial_rep is not None:
                    logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
                    facial_each_file = pose_data["expressions"][::stride]
                    if self.args.facial_norm:
                        facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial

            else:
                assert 120 % self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120'
                stride = int(120/self.args.pose_fps)
                with open(pose_file, "r") as pose_data:
                    for j, line in enumerate(pose_data.readlines()):
                        if j < 431: continue
                        if j % stride != 0: continue
                        data = np.fromstring(line, dtype=float, sep=" ")
                        rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints, 3), "XYZ")
                        rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3)
                        rot_data = rot_data.numpy() * self.joint_mask

                        pose_each_file.append(rot_data)
                        trans_each_file.append(data[:3])

                pose_each_file = np.array(pose_each_file)
                # print(pose_each_file.shape)
                trans_each_file = np.array(trans_each_file)
                shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0)
                if self.args.facial_rep is not None:
                    logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
                    facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json")
                    assert 60 % self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 60'
                    stride = int(60/self.args.pose_fps)
                    if not os.path.exists(facial_file):
                        logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #")
                        self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
                        continue
                    with open(facial_file, 'r') as facial_data_file:
                        facial_data = json.load(facial_data_file)
                        for j, frame_data in enumerate(facial_data['frames']):
                            if j % stride != 0: continue
                            facial_each_file.append(frame_data['weights'])
                    facial_each_file = np.array(facial_each_file)
                    if self.args.facial_norm:
                        facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial

            if self.args.id_rep is not None:
                vid_each_file = np.repeat(np.array(int(f_name.split("_")[0])-1).reshape(1, 1), pose_each_file.shape[0], axis=0)

            if self.args.audio_rep is not None:
                logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #")
                audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
                if not os.path.exists(audio_file):
                    logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #")
                    self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
                    continue
                audio_each_file, sr = librosa.load(audio_file)
                audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr)
                if self.args.audio_rep == "onset+amplitude":
                    from numpy.lib import stride_tricks
                    frame_length = 1024
                    # hop_length = 512
                    shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length)
                    strides = (audio_each_file.strides[-1], audio_each_file.strides[-1])
                    rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides)
                    amplitude_envelope = np.max(np.abs(rolling_view), axis=1)
                    # pad the last frame_length-1 samples
                    amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1])
                    audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames')
                    onset_array = np.zeros(len(audio_each_file), dtype=float)
                    onset_array[audio_onset_f] = 1.0
                    # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape)
                    audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1)
                elif self.args.audio_rep == "mfcc":
                    audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps))
                    audio_each_file = audio_each_file.transpose(1, 0)
                    # print(audio_each_file.shape, pose_each_file.shape)
                if self.args.audio_norm and self.args.audio_rep == "wave16k":
                    audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio
                # print(audio_each_file.shape)

            time_offset = 0
            if self.args.word_rep is not None:
                logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #")
                word_file = f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid"
                if not os.path.exists(word_file):
                    logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #")
                    self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
                    continue
                tgrid = tg.TextGrid.fromFile(word_file)
                if self.args.t_pre_encoder == "bert":
                    from transformers import AutoTokenizer, BertModel
                    tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True)
                    model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval()
                    list_word = []
                    all_hidden = []
                    max_len = 400
                    last = 0
                    word_token_mapping = []
                    first = True
                    for i, word in enumerate(tgrid[0]):
                        last = i
                        if (i % max_len != 0) or (i == 0):
                            if word.mark == "":
                                list_word.append(".")
                            else:
                                list_word.append(word.mark)
                        else:
                            max_counter = max_len
                            str_word = ' '.join(map(str, list_word))
                            if first:
                                global_len = 0
                            end = -1
                            offset_word = []
                            for k, wordvalue in enumerate(list_word):
                                start = end+1
                                end = start+len(wordvalue)
                                offset_word.append((start, end))
                            # print(offset_word)
                            token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping']
                            # print(token_scan)
                            for start, end in offset_word:
                                sub_mapping = []
                                for i, (start_t, end_t) in enumerate(token_scan[1:-1]):
                                    if int(start) <= int(start_t) and int(end_t) <= int(end):
                                        # print(i+global_len)
                                        sub_mapping.append(i+global_len)
                                word_token_mapping.append(sub_mapping)
                            # print(len(word_token_mapping))
                            global_len = word_token_mapping[-1][-1] + 1
                            list_word = []
                            if word.mark == "":
                                list_word.append(".")
                            else:
                                list_word.append(word.mark)

                            with torch.no_grad():
                                inputs = tokenizer(str_word, return_tensors="pt")
                                outputs = model(**inputs)
                                last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :]
                            all_hidden.append(last_hidden_states)

                    # list_word = list_word[:10]
                    if list_word == []:
                        pass
                    else:
                        if first:
                            global_len = 0
                        str_word = ' '.join(map(str, list_word))
                        end = -1
                        offset_word = []
                        for k, wordvalue in enumerate(list_word):
                            start = end+1
                            end = start+len(wordvalue)
                            offset_word.append((start, end))
                        # print(offset_word)
                        token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping']
                        # print(token_scan)
                        for start, end in offset_word:
                            sub_mapping = []
                            for i, (start_t, end_t) in enumerate(token_scan[1:-1]):
                                if int(start) <= int(start_t) and int(end_t) <= int(end):
                                    sub_mapping.append(i+global_len)
                                    # print(sub_mapping)
                            word_token_mapping.append(sub_mapping)
                        # print(len(word_token_mapping))
                        with torch.no_grad():
                            inputs = tokenizer(str_word, return_tensors="pt")
                            outputs = model(**inputs)
                            last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :]
                        all_hidden.append(last_hidden_states)
                    last_hidden_states = np.concatenate(all_hidden, axis=0)

                for i in range(pose_each_file.shape[0]):
                    found_flag = False
                    current_time = i/self.args.pose_fps + time_offset
                    j_last = 0
                    for j, word in enumerate(tgrid[0]):
                        word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
                        if word_s <= current_time and current_time <= word_e:
                            if self.args.word_cache and self.args.t_pre_encoder == 'bert':
                                mapping_index = word_token_mapping[j]
                                # print(mapping_index, word_s, word_e)
                                s_t = np.linspace(word_s, word_e, len(mapping_index)+1)
                                # print(s_t)
                                for tt, t_sep in enumerate(s_t[1:]):
                                    if current_time <= t_sep:
                                        # if len(mapping_index) > 1: print(mapping_index[tt])
                                        word_each_file.append(last_hidden_states[mapping_index[tt]])
                                        break
                            else:
                                if word_n == " ":
                                    word_each_file.append(self.lang_model.PAD_token)
                                else:
                                    word_each_file.append(self.lang_model.get_word_index(word_n))
                            found_flag = True
                            j_last = j
                            break
                        else: continue
                    if not found_flag:
                        if self.args.word_cache and self.args.t_pre_encoder == 'bert':
                            word_each_file.append(last_hidden_states[j_last])
                        else:
                            word_each_file.append(self.lang_model.UNK_token)
                word_each_file = np.array(word_each_file)
                # print(word_each_file.shape)

            if self.args.emo_rep is not None:
                logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #")
                # id like "1_wayne_0_1_1": fields are speaker, name, type, start, end
                rtype, start = int(id_pose.split('_')[2]), int(id_pose.split('_')[3])
                if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6:
                    if start >= 1 and start <= 64:
                        score = 0
                    elif start >= 65 and start <= 72:
                        score = 1
                    elif start >= 73 and start <= 80:
                        score = 2
                    elif start >= 81 and start <= 86:
                        score = 3
                    elif start >= 87 and start <= 94:
                        score = 4
                    elif start >= 95 and start <= 102:
                        score = 5
                    elif start >= 103 and start <= 110:
                        score = 6
                    elif start >= 111 and start <= 118:
                        score = 7
                    else: pass
                else:
                    # you may denote as unknown in the future
                    score = 0
                emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0)
                # print(emo_each_file)

            if self.args.sem_rep is not None:
                logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #")
                sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt"
                sem_all = pd.read_csv(sem_file,
                                      sep='\t',
                                      names=["name", "start_time", "end_time", "duration", "score", "keywords"])
                # we adopt motion-level semantic score here.
                for i in range(pose_each_file.shape[0]):
                    found_flag = False
                    for j, (start, end, score) in enumerate(zip(sem_all['start_time'], sem_all['end_time'], sem_all['score'])):
                        current_time = i/self.args.pose_fps + time_offset
                        if start <= current_time and current_time <= end:
                            sem_each_file.append(score)
                            found_flag = True
                            break
                        else: continue
                    if not found_flag: sem_each_file.append(0.)
                sem_each_file = np.array(sem_each_file)
                # print(sem_each_file)

            filtered_result = self._sample_from_clip(
                dst_lmdb_env,
                audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file,
                vid_each_file, emo_each_file, sem_each_file,
                disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
            )
            for type_key in filtered_result.keys():
                n_filtered_out[type_key] += filtered_result[type_key]

        with dst_lmdb_env.begin() as txn:
            logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan"))
            n_total_filtered = 0
            for type_key, n_filtered in n_filtered_out.items():
                logger.info("{}: {}".format(type_key, n_filtered))
                n_total_filtered += n_filtered
            logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format(
                n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan"))
        dst_lmdb_env.sync()
        dst_lmdb_env.close()

    def _sample_from_clip(
        self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file,
        vid_each_file, emo_each_file, sem_each_file,
        disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
    ):
        """
        for data cleaning, we ignore the data for the first and final n seconds
        for test, we return all data
        """
        # audio_start = int(self.alignment[0] * self.args.audio_fps)
        # pose_start = int(self.alignment[1] * self.args.pose_fps)
        # logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}")
        # audio_each_file = audio_each_file[audio_start:]
        # pose_each_file = pose_each_file[pose_start:]
        # logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}")
        # print(pose_each_file.shape)
        round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps  # assume 1500 frames / 15 fps = 100 s
        # print(round_seconds_skeleton)
        if len(audio_each_file) != 0:
            if self.args.audio_rep == "mfcc":
                round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps
            elif self.args.audio_rep != "wave16k":
                round_seconds_audio = len(audio_each_file) // self.args.audio_fps  # assume 1,600,000 / 16,000 = 100 s
            else:
                round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr
            if len(facial_each_file) != 0:
                round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps
                logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s")
                round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
                max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
                if round_seconds_skeleton != max_round:
                    logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s")
            else:
                logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s")
                round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton)
                max_round = max(round_seconds_audio, round_seconds_skeleton)
                if round_seconds_skeleton != max_round:
                    logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s")

        clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds  # assume [10, 90]s
        clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps  # [160,000, 90*160,000]
        clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps  # [150, 90*15]

        for ratio in self.args.multi_length_training:
            if is_test:  # stride = length for test
                cut_length = clip_e_f_pose - clip_s_f_pose
                self.args.stride = cut_length
                self.max_length = cut_length
            else:
                self.args.stride = int(ratio*self.ori_stride)
                cut_length = int(self.ori_length*ratio)

            num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1
            logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}")
            logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}")

            if len(audio_each_file) != 0:
                audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps)
                """
                for audio sr = 16000, fps = 15, pose_length = 34,
                audio short length = 36266.7 -> 36266
                this error is fine.
                """
                logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}")

            n_filtered_out = defaultdict(int)
            sample_pose_list = []
            sample_audio_list = []
            sample_facial_list = []
            sample_shape_list = []
            sample_word_list = []
            sample_emo_list = []
            sample_sem_list = []
            sample_vid_list = []
            sample_trans_list = []

            for i in range(num_subdivision):  # cut into around 2s chips (self npose)
                start_idx = clip_s_f_pose + i * self.args.stride
                fin_idx = start_idx + cut_length
                sample_pose = pose_each_file[start_idx:fin_idx]
                sample_trans = trans_each_file[start_idx:fin_idx]
                sample_shape = shape_each_file[start_idx:fin_idx]
                # print(sample_pose.shape)
                if self.args.audio_rep is not None:
                    audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps)
                    audio_end = audio_start + audio_short_length
                    sample_audio = audio_each_file[audio_start:audio_end]
                else:
                    sample_audio = np.array([-1])
                sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1])
                sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1])
                sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1])
                sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1])
                sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1])

                if sample_pose is not None:
                    # filtering motion skeleton data
                    sample_pose, filtering_message = MotionPreprocessor(sample_pose).get()
                    is_correct_motion = (len(sample_pose) != 0)
                    if is_correct_motion or disable_filtering:
                        sample_pose_list.append(sample_pose)
                        sample_audio_list.append(sample_audio)
                        sample_facial_list.append(sample_facial)
                        sample_shape_list.append(sample_shape)
                        sample_word_list.append(sample_word)
                        sample_vid_list.append(sample_vid)
                        sample_emo_list.append(sample_emo)
                        sample_sem_list.append(sample_sem)
                        sample_trans_list.append(sample_trans)
                    else:
                        n_filtered_out[filtering_message] += 1

            if len(sample_pose_list) > 0:
                with dst_lmdb_env.begin(write=True) as txn:
                    for pose, audio, facial, shape, word, vid, emo, sem, trans in zip(
                            sample_pose_list,
                            sample_audio_list,
                            sample_facial_list,
                            sample_shape_list,
                            sample_word_list,
                            sample_vid_list,
                            sample_emo_list,
                            sample_sem_list,
                            sample_trans_list,):
                        k = "{:005}".format(self.n_out_samples).encode("ascii")
                        v = [pose, audio, facial, shape, word, emo, sem, vid, trans]
                        v = pickle.dumps(v, 5)
                        txn.put(k, v)
                        self.n_out_samples += 1
        return n_filtered_out

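    # --- Editorial aside, not part of the uploaded file ------------------------
    # Worked example for the windowing arithmetic above. Assume pose_fps = 15,
    # a 100 s clip cleaned to [10, 90] s, cut_length = 34 and stride = 10:
    #
    #     clip_s_f_pose = 10 * 15 = 150,  clip_e_f_pose = 90 * 15 = 1350
    #     num_subdivision = floor((1350 - 150 - 34) / 10) + 1 = 117
    #
    # and window i reads audio starting at
    #     clip_s_f_audio + floor(i * stride * audio_fps / pose_fps),
    # so pose and audio windows stay aligned even though the two rates differ.
    # ----------------------------------------------------------------------------
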
    def __getitem__(self, idx):
        with self.lmdb_env.begin(write=False) as txn:
            key = "{:005}".format(idx).encode("ascii")
            sample = txn.get(key)
            sample = pickle.loads(sample)
            tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample
            # print(in_shape)
            # vid = torch.from_numpy(vid).int()
            emo = torch.from_numpy(emo).int()
            sem = torch.from_numpy(sem).float()
            in_audio = torch.from_numpy(in_audio).float()
            in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int()
            if self.loader_type == "test":
                tar_pose = torch.from_numpy(tar_pose).float()
                trans = torch.from_numpy(trans).float()
                in_facial = torch.from_numpy(in_facial).float()
                vid = torch.from_numpy(vid).float()
                in_shape = torch.from_numpy(in_shape).float()
            else:
                in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
                trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float()
                vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float()
                tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float()
                in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float()
            return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans}


class MotionPreprocessor:
    def __init__(self, skeletons):
        self.skeletons = skeletons
        # self.mean_pose = mean_pose
        self.filtering_message = "PASS"

    def get(self):
        assert (self.skeletons is not None)

        # filtering
        if len(self.skeletons) != 0:
            if self.check_pose_diff():
                self.skeletons = []
                self.filtering_message = "pose"
            # elif self.check_spine_angle():
            #     self.skeletons = []
            #     self.filtering_message = "spine angle"
            # elif self.check_static_motion():
            #     self.skeletons = []
            #     self.filtering_message = "motion"

        # if self.skeletons != []:
        #     self.skeletons = self.skeletons.tolist()
        #     for i, frame in enumerate(self.skeletons):
        #         assert not np.isnan(self.skeletons[i]).any()  # missing joints

        return self.skeletons, self.filtering_message

    def check_static_motion(self, verbose=True):
        def get_variance(skeleton, joint_idx):
            wrist_pos = skeleton[:, joint_idx]
            variance = np.sum(np.var(wrist_pos, axis=0))
            return variance

        left_arm_var = get_variance(self.skeletons, 6)
        right_arm_var = get_variance(self.skeletons, 9)

        th = 0.0014  # exclude 13110
        # th = 0.002  # exclude 16905
        if left_arm_var < th and right_arm_var < th:
            if verbose:
                print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
            return True
        else:
            if verbose:
                print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
            return False

    def check_pose_diff(self, verbose=False):
        # diff = np.abs(self.skeletons - self.mean_pose)  # 186*1
        # diff = np.mean(diff)

        # # th = 0.017
        # th = 0.02  # exclude 3594
        # if diff < th:
        #     if verbose:
        #         print("skip - check_pose_diff {:.5f}".format(diff))
        #     return True
        # # th = 3.5  # exclude 3594
        # # if 3.5 < diff < 5:
        # #     if verbose:
        # #         print("skip - check_pose_diff {:.5f}".format(diff))
        # #     return True
        # else:
        #     if verbose:
        #         print("pass - check_pose_diff {:.5f}".format(diff))
        return False

    def check_spine_angle(self, verbose=True):
        def angle_between(v1, v2):
            v1_u = v1 / np.linalg.norm(v1)
            v2_u = v2 / np.linalg.norm(v2)
            return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))

        angles = []
        for i in range(self.skeletons.shape[0]):
            spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0]
            angle = angle_between(spine_vec, [0, -1, 0])
            angles.append(angle)

        if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20:  # exclude 4495
        # if np.rad2deg(max(angles)) > 20:  # exclude 8270
            if verbose:
                print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles)))
            return True
        else:
            if verbose:
                print("pass - check_spine_angle {:.5f}".format(max(angles)))
            return False
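The "onset+amplitude" feature computed inline in cache_generation above is worth isolating. The following is a minimal editorial sketch, assuming only numpy and librosa; frame_length and the 16 kHz rate mirror the values hard-coded above. One named deviation: it places onset marks at sample positions via units='samples', whereas the code above indexes the per-sample array with frame indices.

import numpy as np
import librosa
from numpy.lib import stride_tricks

def onset_amplitude(y, sr=16000, frame_length=1024):
    """Per-sample feature: [rolling amplitude envelope, binary onset mark]."""
    # Sliding window over the waveform via stride tricks (no copy).
    shape = (y.shape[-1] - frame_length + 1, frame_length)
    strides = (y.strides[-1], y.strides[-1])
    windows = stride_tricks.as_strided(y, shape=shape, strides=strides)
    env = np.max(np.abs(windows), axis=1)
    # Pad so the envelope keeps one value per input sample.
    env = np.pad(env, (0, frame_length - 1), mode='constant', constant_values=env[-1])
    # Binary onset indicator at sample resolution.
    onset_samples = librosa.onset.onset_detect(y=y, sr=sr, units='samples')
    onsets = np.zeros(len(y), dtype=float)
    onsets[onset_samples] = 1.0
    return np.stack([env, onsets], axis=1)  # shape (n_samples, 2)

Calling onset_amplitude(y) on a 16 kHz waveform returns an (n_samples, 2) array analogous to what cache_generation stores when audio_rep == "onset+amplitude".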
dataloaders/beat_sep_lower.py
ADDED
@@ -0,0 +1,430 @@
import os
import pickle
import math
import shutil
import numpy as np
import lmdb
import pandas as pd
import torch
import glob
import json
from dataloaders.build_vocab import Vocab
from termcolor import colored
from loguru import logger
from collections import defaultdict
from torch.utils.data import Dataset
import torch.distributed as dist
import smplx
from .utils.audio_features import process_audio_data
from .data_tools import joints_list
from .utils.other_tools import MultiLMDBManager
from .utils.motion_rep_transfer import process_smplx_motion
from .utils.mis_features import process_semantic_data, process_emotion_data
from .utils.text_features import process_word_data
from .utils.data_sample import sample_from_clip
import time


class CustomDataset(Dataset):
    def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
        self.args = args
        self.loader_type = loader_type

        # Set rank safely - handle cases where distributed training is not yet initialized
        try:
            if torch.distributed.is_initialized():
                self.rank = torch.distributed.get_rank()
            else:
                self.rank = 0
        except Exception:
            self.rank = 0

        # Initialize basic parameters
        self.ori_stride = self.args.stride
        self.ori_length = self.args.pose_length
        self.alignment = [0, 0]  # for trinity

        # Initialize SMPLX model
        self.smplx = smplx.create(
            self.args.data_path_1+"smplx_models/",
            model_type='smplx',
            gender='NEUTRAL_2020',
            use_face_contour=False,
            num_betas=300,
            num_expression_coeffs=100,
            ext='npz',
            use_pca=False,
        ).cuda().eval()

        if self.args.word_rep is not None:
            with open(f"{self.args.data_path}weights/vocab.pkl", 'rb') as f:
                self.lang_model = pickle.load(f)

        # Load and process split rules
        self._process_split_rules()

        # Initialize data directories and lengths
        self._init_data_paths()

        if self.args.beat_align:
            if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"):
                self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
            self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")

        # Build or load cache
        self._init_cache(build_cache)

    def _process_split_rules(self):
        """Process dataset split rules."""
        split_rule = pd.read_csv(self.args.data_path+"train_test_split.csv")
        self.selected_file = split_rule.loc[
            (split_rule['type'] == self.loader_type) &
            (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
        ]

        if self.args.additional_data and self.loader_type == 'train':
            split_b = split_rule.loc[
                (split_rule['type'] == 'additional') &
                (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
            ]
            self.selected_file = pd.concat([self.selected_file, split_b])

        if self.selected_file.empty:
            logger.warning(f"{self.loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead")
            self.selected_file = split_rule.loc[
                (split_rule['type'] == 'train') &
                (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))
            ]
            self.selected_file = self.selected_file.iloc[0:8]

    def _init_data_paths(self):
        """Initialize data directories and lengths."""
        self.data_dir = self.args.data_path

        if self.loader_type == "test":
            self.args.multi_length_training = [1.0]

        self.max_length = int(self.args.pose_length * self.args.multi_length_training[-1])
        self.max_audio_pre_len = math.floor(self.args.pose_length / self.args.pose_fps * self.args.audio_sr)

        if self.max_audio_pre_len > self.args.test_length * self.args.audio_sr:
            self.max_audio_pre_len = self.args.test_length * self.args.audio_sr

        if self.args.test_clip and self.loader_type == "test":
            self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + "_clip" + f"/{self.args.pose_rep}_cache"
        else:
            self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + f"/{self.args.pose_rep}_cache"

    def _init_cache(self, build_cache):
        """Initialize or build cache."""
        self.lmdb_envs = {}
        self.mapping_data = None

        if build_cache and self.rank == 0:
            self.build_cache(self.preloaded_dir)

        # In DDP mode, ensure all processes wait for cache building to complete
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        # Try to regenerate cache if corrupted (only on rank 0 to avoid race conditions)
        if self.rank == 0:
            self.regenerate_cache_if_corrupted()

        # Wait for cache regeneration to complete
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        self.load_db_mapping()

| 144 |
+
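+    # Cache location note (added for clarity): the preloaded cache built below lives at
+    # root_path + cache_path + loader_type + "/<pose_rep>_cache"; e.g. a train cache for
+    # pose_rep "smplxflame_30" would sit under .../train/smplxflame_30_cache (illustrative
+    # values only; the actual components come from the config).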
+    def build_cache(self, preloaded_dir):
+        """Build the dataset cache."""
+        logger.info(f"Audio bit rate: {self.args.audio_fps}")
+        logger.info("Reading data '{}'...".format(self.data_dir))
+        logger.info("Creating the dataset cache...")
+
+        if self.args.new_cache and os.path.exists(preloaded_dir):
+            shutil.rmtree(preloaded_dir)
+
+        if os.path.exists(preloaded_dir):
+            # if the dir is empty, that means we still need to build the cache
+            if not os.listdir(preloaded_dir):
+                self.cache_generation(
+                    preloaded_dir,
+                    self.args.disable_filtering,
+                    self.args.clean_first_seconds,
+                    self.args.clean_final_seconds,
+                    is_test=False
+                )
+            else:
+                logger.info("Found the cache {}".format(preloaded_dir))
+        elif self.loader_type == "test":
+            self.cache_generation(preloaded_dir, True, 0, 0, is_test=True)
+        else:
+            self.cache_generation(
+                preloaded_dir,
+                self.args.disable_filtering,
+                self.args.clean_first_seconds,
+                self.args.clean_final_seconds,
+                is_test=False
+            )
+
+    def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
+        """Generate cache for the dataset."""
+        if not os.path.exists(out_lmdb_dir):
+            os.makedirs(out_lmdb_dir)
+
+        # Initialize the multi-LMDB manager (10 GB per shard)
+        lmdb_manager = MultiLMDBManager(out_lmdb_dir, max_db_size=10*1024*1024*1024)
+
+        self.n_out_samples = 0
+        n_filtered_out = defaultdict(int)
+
+        for index, file_name in self.selected_file.iterrows():
+            f_name = file_name["id"]
+            ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
+            pose_file = os.path.join(self.data_dir, self.args.pose_rep, f_name + ext)
+
+            # Process data
+            data = self._process_file_data(f_name, pose_file, ext)
+            if data is None:
+                continue
+
+            # Sample from clip
+            filtered_result, self.n_out_samples = sample_from_clip(
+                lmdb_manager=lmdb_manager,
+                audio_file=pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav"),
+                audio_each_file=data['audio'],
+                pose_each_file=data['pose'],
+                trans_each_file=data['trans'],
+                trans_v_each_file=data['trans_v'],
+                shape_each_file=data['shape'],
+                facial_each_file=data['facial'],
+                word_each_file=data['word'],
+                vid_each_file=data['vid'],
+                emo_each_file=data['emo'],
+                sem_each_file=data['sem'],
+                args=self.args,
+                ori_stride=self.ori_stride,
+                ori_length=self.ori_length,
+                disable_filtering=disable_filtering,
+                clean_first_seconds=clean_first_seconds,
+                clean_final_seconds=clean_final_seconds,
+                is_test=is_test,
+                n_out_samples=self.n_out_samples
+            )
+
+            for type_key in filtered_result:
+                n_filtered_out[type_key] += filtered_result[type_key]
+
+        lmdb_manager.close()
+
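+    # _process_file_data returns a dict keyed by modality ('pose', 'trans', 'trans_v',
+    # 'shape', 'audio', 'facial', 'word', 'emo', 'sem', 'vid'); a None return means a
+    # required modality was missing, and the caller skips the whole recording.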
+    def _process_file_data(self, f_name, pose_file, ext):
+        """Process all data for a single file."""
+        data = {
+            'pose': None, 'trans': None, 'trans_v': None, 'shape': None,
+            'audio': None, 'facial': None, 'word': None, 'emo': None,
+            'sem': None, 'vid': None
+        }
+
+        # Process motion data
+        logger.info(colored(f"# ---- Building cache for Pose {f_name} ---- #", "blue"))
+        if "smplx" in self.args.pose_rep:
+            motion_data = process_smplx_motion(pose_file, self.smplx, self.args.pose_fps, self.args.facial_rep)
+        else:
+            raise ValueError(f"Unknown pose representation '{self.args.pose_rep}'.")
+
+        if motion_data is None:
+            return None
+
+        data.update(motion_data)
+
+        # Process speaker ID
+        if self.args.id_rep is not None:
+            speaker_id = int(f_name.split("_")[0]) - 1
+            data['vid'] = np.repeat(np.array(speaker_id).reshape(1, 1), data['pose'].shape[0], axis=0)
+        else:
+            data['vid'] = np.array([-1])
+
+        # Process audio if needed
+        if self.args.audio_rep is not None:
+            audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
+            data = process_audio_data(audio_file, self.args, data, f_name, self.selected_file)
+            if data is None:
+                return None
+
+        # Process emotion if needed
+        if self.args.emo_rep is not None:
+            data = process_emotion_data(f_name, data, self.args)
+            if data is None:
+                return None
+
+        # Process word data if needed
+        if self.args.word_rep is not None:
+            word_file = f"{self.data_dir}{self.args.word_rep}/{f_name}.TextGrid"
+            data = process_word_data(self.data_dir, word_file, self.args, data, f_name, self.selected_file, self.lang_model)
+            if data is None:
+                return None
+
+        # Process semantic data if needed
+        if self.args.sem_rep is not None:
+            sem_file = f"{self.data_dir}{self.args.sem_rep}/{f_name}.txt"
+            data = process_semantic_data(sem_file, self.args, data, f_name)
+            if data is None:
+                return None
+
+        return data
+
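+    # Expected on-disk layout (inferred from the reads below): sample_db_mapping.pkl holds
+    # {'mapping': [shard index per sample, ...], 'db_paths': [LMDB path per shard, ...]};
+    # the exact writer-side format is owned by MultiLMDBManager.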
+    def load_db_mapping(self):
+        """Load database mapping from file."""
+        mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl")
+        backup_path = os.path.join(self.preloaded_dir, "sample_db_mapping_backup.pkl")
+
+        # Check if file exists and is readable
+        if not os.path.exists(mapping_path):
+            raise FileNotFoundError(f"Mapping file not found: {mapping_path}")
+
+        # Check file size to ensure it's not empty
+        file_size = os.path.getsize(mapping_path)
+        if file_size == 0:
+            raise ValueError(f"Mapping file is empty: {mapping_path}")
+
+        print(f"Loading mapping file: {mapping_path} (size: {file_size} bytes)")
+
+        # Add error handling and retry logic for pickle loading
+        max_retries = 3
+        for attempt in range(max_retries):
+            try:
+                with open(mapping_path, 'rb') as f:
+                    self.mapping_data = pickle.load(f)
+                print(f"Successfully loaded mapping data with {len(self.mapping_data.get('mapping', []))} samples")
+                break
+            except (EOFError, pickle.UnpicklingError) as e:
+                if attempt < max_retries - 1:
+                    print(f"Warning: Failed to load pickle file (attempt {attempt + 1}/{max_retries}): {e}")
+                    print(f"File path: {mapping_path}")
+
+                    # Try backup file if main file is corrupted
+                    if os.path.exists(backup_path) and os.path.getsize(backup_path) > 0:
+                        print("Trying backup file...")
+                        try:
+                            with open(backup_path, 'rb') as f:
+                                self.mapping_data = pickle.load(f)
+                            print(f"Successfully loaded mapping data from backup with {len(self.mapping_data.get('mapping', []))} samples")
+                            break
+                        except Exception as backup_e:
+                            print(f"Backup file also failed: {backup_e}")
+
+                    print("Retrying...")
+                    time.sleep(1)  # Wait a bit before retrying
+                else:
+                    print(f"Error: Failed to load pickle file after {max_retries} attempts: {e}")
+                    print(f"File path: {mapping_path}")
+                    print("Please check if the file is corrupted or incomplete.")
+                    print("You may need to regenerate the cache files.")
+                    raise
+
+        # Update paths from test to test_clip if needed
+        if self.loader_type == "test" and self.args.test_clip:
+            updated_paths = []
+            for path in self.mapping_data['db_paths']:
+                updated_path = path.replace("test/", "test_clip/")
+                updated_paths.append(updated_path)
+            self.mapping_data['db_paths'] = updated_paths
+
+            # In DDP mode, avoid modifying shared files to prevent race conditions;
+            # only the in-memory copy is updated.
+            print("Updated test paths for test_clip mode (avoiding file modification in DDP)")
+
+        self.n_samples = len(self.mapping_data['mapping'])
+
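+    # LMDB environments are opened lazily and memoized per shard index, so a DataLoader
+    # worker only pays the lmdb.open() cost for the shards it actually reads from.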
+    def get_lmdb_env(self, db_idx):
+        """Get LMDB environment for given database index."""
+        if db_idx not in self.lmdb_envs:
+            db_path = self.mapping_data['db_paths'][db_idx]
+            self.lmdb_envs[db_idx] = lmdb.open(db_path, readonly=True, lock=False)
+        return self.lmdb_envs[db_idx]
+
+    def __len__(self):
+        """Return the total number of samples in the dataset."""
+        return self.n_samples
+
+    def __getitem__(self, idx):
+        """Get a single sample from the dataset."""
+        db_idx = self.mapping_data['mapping'][idx]
+        lmdb_env = self.get_lmdb_env(db_idx)
+
+        with lmdb_env.begin(write=False) as txn:
+            key = "{:008d}".format(idx).encode("ascii")  # zero-padded key, e.g. b"00000042"
+            sample = txn.get(key)
+            sample = pickle.loads(sample)
+
+        tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans, trans_v, audio_name = sample
+
+        # Convert data to tensors with appropriate types
+        processed_data = self._convert_to_tensors(
+            tar_pose, in_audio, in_facial, in_shape, in_word,
+            emo, sem, vid, trans, trans_v
+        )
+
+        processed_data['audio_name'] = audio_name
+        return processed_data
+
+    def _convert_to_tensors(self, tar_pose, in_audio, in_facial, in_shape, in_word,
+                            emo, sem, vid, trans, trans_v):
+        """Convert numpy arrays to tensors with appropriate types."""
+        data = {
+            'emo': torch.from_numpy(emo).int(),
+            'sem': torch.from_numpy(sem).float(),
+            'audio_onset': torch.from_numpy(in_audio).float(),
+            'word': torch.from_numpy(in_word).int()
+        }
+
+        if self.loader_type == "test":
+            data.update({
+                'pose': torch.from_numpy(tar_pose).float(),
+                'trans': torch.from_numpy(trans).float(),
+                'trans_v': torch.from_numpy(trans_v).float(),
+                'facial': torch.from_numpy(in_facial).float(),
+                'id': torch.from_numpy(vid).float(),
+                'beta': torch.from_numpy(in_shape).float()
+            })
+        else:
+            data.update({
+                'pose': torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float(),
+                'trans': torch.from_numpy(trans).reshape((trans.shape[0], -1)).float(),
+                'trans_v': torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float(),
+                'facial': torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float(),
+                'id': torch.from_numpy(vid).reshape((vid.shape[0], -1)).float(),
+                'beta': torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
+            })
+
+        return data
+
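+    # Recovery flow: a truncated or corrupted mapping pickle is deleted and the whole
+    # cache is rebuilt; load_db_mapping() additionally retries and falls back to the
+    # *_backup.pkl copy before giving up.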
+    def regenerate_cache_if_corrupted(self):
+        """Regenerate cache if the pickle file is corrupted."""
+        mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl")
+
+        if os.path.exists(mapping_path):
+            try:
+                # Try to load the file to check if it's corrupted
+                with open(mapping_path, 'rb') as f:
+                    pickle.load(f)  # fully parse to verify integrity
+                return False  # File is not corrupted
+            except (EOFError, pickle.UnpicklingError):
+                print(f"Detected corrupted pickle file: {mapping_path}")
+                print("Regenerating cache...")
+
+                # Remove corrupted file
+                os.remove(mapping_path)
+
+                # Regenerate cache
+                self.build_cache(self.preloaded_dir)
+                return True
+
+        return False
dataloaders/beat_sep_single.py
ADDED
@@ -0,0 +1,693 @@
+import os
+import pickle
+import math
+import shutil
+import numpy as np
+import lmdb
+import textgrid as tg
+import pandas as pd
+import torch
+import glob
+import json
+from termcolor import colored
+from loguru import logger
+from collections import defaultdict
+from torch.utils.data import Dataset
+import torch.distributed as dist
+#import pyarrow
+import librosa
+import smplx
+
+from .build_vocab import Vocab
+from models.utils.wav2vec import Wav2Vec2Model
+from .data_tools import joints_list
+from .utils import rotation_conversions as rc
+from .utils import other_tools
+
+import torch.nn.functional as F
+
+
+class _FallbackLangModel:
+    """Minimal vocabulary that grows on demand for demo/test mode."""
+
+    def __init__(self) -> None:
+        self.PAD_token = 0
+        self.UNK_token = 1
+        self._word_to_idx = {"<pad>": self.PAD_token, "<unk>": self.UNK_token}
+        self.word_embedding_weights = np.zeros((2, 300), dtype=np.float32)
+
+    def get_word_index(self, word: str) -> int:
+        if word is None:
+            return self.UNK_token
+        cleaned = word.strip().lower()
+        if not cleaned:
+            return self.PAD_token
+        return self._word_to_idx["<unk>"]
+
+
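+# Demo-mode behaviour of the fallback vocabulary: every real word maps to <unk> (1) and
+# blank/whitespace marks map to <pad> (0), e.g.
+#   lm = _FallbackLangModel()
+#   lm.get_word_index("hello")  # -> 1
+#   lm.get_word_index("   ")    # -> 0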
+class CustomDataset(Dataset):
+    def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
+        self.audio_file_path = args.audio_file_path
+        self.textgrid_file_path = args.textgrid_file_path
+        self.default_pose_file = "./demo/examples/2_scott_0_1_1.npz"
+
+        self.args = args
+        self.loader_type = loader_type
+
+        self.rank = 0  # single-process demo deployment, no DDP rank lookup
+        self.ori_stride = self.args.stride
+        self.ori_length = self.args.pose_length
+        self.alignment = [0, 0]  # for trinity
+
+        self.ori_joint_list = joints_list[self.args.ori_joints]
+        self.tar_joint_list = joints_list[self.args.tar_joints]
+        if 'smplx' in self.args.pose_rep:
+            self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3)
+            self.joints = len(list(self.tar_joint_list.keys()))
+            for joint_name in self.tar_joint_list:
+                self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
+        else:
+            self.joints = len(list(self.ori_joint_list.keys()))+1
+            self.joint_mask = np.zeros(self.joints*3)
+            for joint_name in self.tar_joint_list:
+                if joint_name == "Hips":
+                    self.joint_mask[3:6] = 1
+                else:
+                    self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
+        # select trainable joints
+        self.smplx = smplx.create(
+            self.args.data_path_1+"smplx_models/",
+            model_type='smplx',
+            gender='NEUTRAL_2020',
+            use_face_contour=False,
+            num_betas=300,
+            num_expression_coeffs=100,
+            ext='npz',
+            use_pca=False,
+        ).eval()
+
+        if loader_type == 'test':
+            # In demo/test mode, skip dataset CSV and use provided paths
+            self.selected_file = pd.DataFrame([{
+                'id': 'demo_0',
+                'audio_path': self.args.audio_file_path or './demo/examples/2_scott_0_1_1.wav',
+                'textgrid_path': self.args.textgrid_file_path or None,
+                'pose_path': self.default_pose_file,
+            }])
+        else:
+            split_rule = pd.read_csv(args.data_path+"train_test_split.csv")
+            self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
+            if args.additional_data and loader_type == 'train':
+                split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
+                self.selected_file = pd.concat([self.selected_file, split_b])
+            if self.selected_file.empty:
+                logger.warning(f"{loader_type} split is empty for speakers {self.args.training_speakers}, falling back to the first 8 train files")
+                self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
+                self.selected_file = self.selected_file.iloc[0:8]
+        self.data_dir = args.data_path
+
+        if loader_type == "test":
+            self.args.multi_length_training = [1.0]
+        self.max_length = int(args.pose_length * self.args.multi_length_training[-1])
+        self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr)
+        if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr:
+            self.max_audio_pre_len = self.args.test_length*self.args.audio_sr
+
+        if args.word_rep is not None:
+            vocab_path = f"{args.data_path}weights/vocab.pkl"
+            if loader_type == 'test':
+                logger.info("Instantiating fallback vocabulary for test loader")
+                self.lang_model = _FallbackLangModel()
+            elif os.path.exists(vocab_path):
+                with open(vocab_path, 'rb') as f:
+                    self.lang_model = pickle.load(f)
+            else:
+                logger.warning(f"vocab.pkl not found at {vocab_path}, using fallback vocabulary")
+                self.lang_model = _FallbackLangModel()
+        else:
+            self.lang_model = None
+
+        preloaded_dir = self.args.tmp_dir + '/' + loader_type + f"/{args.pose_rep}_cache"
+
+        if self.args.beat_align and loader_type != 'test':
+            if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"):
+                self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
+            self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy")
+        else:
+            self.avg_vel = None
+
+        if build_cache and self.rank == 0:
+            self.build_cache(preloaded_dir)
+        self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False)
+        with self.lmdb_env.begin() as txn:
+            self.n_samples = txn.stat()["entries"]
+
+    def calculate_mean_velocity(self, save_path):
+        # Stub for demo mode: write zero velocity to avoid heavy computation
+        avg_vel = np.zeros(55)
+        np.save(save_path, avg_vel)
+
+    def build_cache(self, preloaded_dir):
+        logger.info(f"Audio bit rate: {self.args.audio_fps}")
+        logger.info("Reading data '{}'...".format(self.data_dir))
+        logger.info("Creating the dataset cache...")
+        if self.args.new_cache:
+            if os.path.exists(preloaded_dir):
+                shutil.rmtree(preloaded_dir)
+        if os.path.exists(preloaded_dir):
+            logger.info("Found the cache {}".format(preloaded_dir))
+        elif self.loader_type == "test":
+            self.cache_generation(
+                preloaded_dir, True,
+                0, 0,
+                is_test=True)
+        else:
+            self.cache_generation(
+                preloaded_dir, self.args.disable_filtering,
+                self.args.clean_first_seconds, self.args.clean_final_seconds,
+                is_test=False)
+
+    def __len__(self):
+        return self.n_samples
+
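+    # cache_generation in this demo loader handles exactly one (pose, audio, textgrid)
+    # triple: load the default SMPL-X clip, derive foot contacts, extract the requested
+    # audio/word/emotion/semantic streams, then window everything via _sample_from_clip.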
+    def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
+        # if "wav2vec2" in self.args.audio_rep:
+        #     self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h")
+        #     self.wav2vec_model.feature_extractor._freeze_parameters()
+        #     self.wav2vec_model = self.wav2vec_model.cuda()
+        #     self.wav2vec_model.eval()
+
+        self.n_out_samples = 0
+        # create db for samples
+        if not os.path.exists(out_lmdb_dir):
+            os.makedirs(out_lmdb_dir)
+        dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size=int(1024 ** 3 * 500))  # 500 GB
+        n_filtered_out = defaultdict(int)
+
+        # f_name = file_name["id"]
+        ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
+        pose_file = self.default_pose_file
+        pose_each_file = []
+        trans_each_file = []
+        trans_v_each_file = []
+        shape_each_file = []
+        audio_each_file = []
+        facial_each_file = []
+        word_each_file = []
+        emo_each_file = []
+        sem_each_file = []
+        vid_each_file = []
+        id_pose = "tmp"  # e.g. 1_wayne_0_1_1 in the full dataset
+
+        logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue"))
+        if "smplx" in self.args.pose_rep:
+            pose_data = np.load(pose_file, allow_pickle=True)
+            assert 30 % self.args.pose_fps == 0, 'pose_fps should be a divisor of 30'
+            stride = int(30/self.args.pose_fps)
+            pose_each_file = pose_data["poses"][::stride]
+            trans_each_file = pose_data["trans"][::stride]
+            trans_each_file[:,0] = trans_each_file[:,0] - trans_each_file[0,0]
+            trans_each_file[:,2] = trans_each_file[:,2] - trans_each_file[0,2]
+            # velocity representation: frame-to-frame deltas on x/z, absolute height on y
+            trans_v_each_file = np.zeros_like(trans_each_file)
+            trans_v_each_file[1:,0] = trans_each_file[1:,0] - trans_each_file[:-1,0]
+            trans_v_each_file[0,0] = trans_v_each_file[1,0]
+            trans_v_each_file[1:,2] = trans_each_file[1:,2] - trans_each_file[:-1,2]
+            trans_v_each_file[0,2] = trans_v_each_file[1,2]
+            trans_v_each_file[:,1] = trans_each_file[:,1]
+            shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0)
+
+            assert self.args.pose_fps == 30, "pose_fps should be 30"
+            m_data = np.load(pose_file, allow_pickle=True)
+            betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"]
+            n, c = poses.shape[0], poses.shape[1]
+            betas = betas.reshape(1, 300)
+            betas = np.tile(betas, (n, 1))
+            betas = torch.from_numpy(betas).float()
+            poses = torch.from_numpy(poses.reshape(n, c)).float()
+            exps = torch.from_numpy(exps.reshape(n, 100)).float()
+            trans = torch.from_numpy(trans.reshape(n, 3)).float()
+            max_length = 128  # run the SMPL-X forward pass in 128-frame chunks to bound memory
+            s, r = n//max_length, n%max_length
+            all_tensor = []
+            for i in range(s):
+                with torch.no_grad():
+                    joints = self.smplx(
+                        betas=betas[i*max_length:(i+1)*max_length],
+                        transl=trans[i*max_length:(i+1)*max_length],
+                        expression=exps[i*max_length:(i+1)*max_length],
+                        jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69],
+                        global_orient=poses[i*max_length:(i+1)*max_length, :3],
+                        body_pose=poses[i*max_length:(i+1)*max_length, 3:21*3+3],
+                        left_hand_pose=poses[i*max_length:(i+1)*max_length, 25*3:40*3],
+                        right_hand_pose=poses[i*max_length:(i+1)*max_length, 40*3:55*3],
+                        return_verts=True,
+                        return_joints=True,
+                        leye_pose=poses[i*max_length:(i+1)*max_length, 69:72],
+                        reye_pose=poses[i*max_length:(i+1)*max_length, 72:75],
+                    )['joints'][:, (7,8,10,11), :].reshape(max_length, 4, 3).cpu()
+                all_tensor.append(joints)
+            if r != 0:
+                with torch.no_grad():
+                    joints = self.smplx(
+                        betas=betas[s*max_length:s*max_length+r],
+                        transl=trans[s*max_length:s*max_length+r],
+                        expression=exps[s*max_length:s*max_length+r],
+                        jaw_pose=poses[s*max_length:s*max_length+r, 66:69],
+                        global_orient=poses[s*max_length:s*max_length+r, :3],
+                        body_pose=poses[s*max_length:s*max_length+r, 3:21*3+3],
+                        left_hand_pose=poses[s*max_length:s*max_length+r, 25*3:40*3],
+                        right_hand_pose=poses[s*max_length:s*max_length+r, 40*3:55*3],
+                        return_verts=True,
+                        return_joints=True,
+                        leye_pose=poses[s*max_length:s*max_length+r, 69:72],
+                        reye_pose=poses[s*max_length:s*max_length+r, 72:75],
+                    )['joints'][:, (7,8,10,11), :].reshape(r, 4, 3).cpu()
+                all_tensor.append(joints)
+            joints = torch.cat(all_tensor, axis=0)  # all, 4, 3 (ankle/foot joints)
+            feetv = torch.zeros(joints.shape[1], joints.shape[0])
+            joints = joints.permute(1, 0, 2)
+            feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1)
+            contacts = (feetv < 0.01).numpy().astype(float)
+            contacts = contacts.transpose(1, 0)
+            pose_each_file = pose_each_file * self.joint_mask
+            pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)]
+            pose_each_file = np.concatenate([pose_each_file, contacts], axis=1)
+
+            if self.args.facial_rep is not None:
+                logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
+                facial_each_file = pose_data["expressions"][::stride]
+                if self.args.facial_norm:
+                    facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial
+
+        if self.args.id_rep is not None:
+            vid_each_file = np.repeat(np.array(int(999)-1).reshape(1, 1), pose_each_file.shape[0], axis=0)
+
+        if self.args.audio_rep is not None:
+            logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #")
+            audio_file = self.audio_file_path
+            if not os.path.exists(audio_file):
+                logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #")
+                self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
+
+            audio_save_path = audio_file.replace("wave16k", "onset_amplitude").replace(".wav", ".npy")
+
+            if self.args.audio_rep == "onset+amplitude":
+                audio_each_file, sr = librosa.load(audio_file)
+                audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr)
+                from numpy.lib import stride_tricks
+                frame_length = 1024
+                # hop_length = 512
+                # sliding max over a strided view gives the per-sample amplitude envelope
+                shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length)
+                strides = (audio_each_file.strides[-1], audio_each_file.strides[-1])
+                rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides)
+                amplitude_envelope = np.max(np.abs(rolling_view), axis=1)
+                # pad the last frame_length-1 samples
+                amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1])
+                audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames')
+                onset_array = np.zeros(len(audio_each_file), dtype=float)
+                onset_array[audio_onset_f] = 1.0
+                audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1)
+
+            elif self.args.audio_rep == "mfcc":
+                # NOTE: this branch assumes audio_each_file already holds the raw waveform
+                audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps))
+                audio_each_file = audio_each_file.transpose(1, 0)
+            if self.args.audio_norm and self.args.audio_rep == "wave16k":
+                audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio
+
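+        # Frame-to-word alignment below: pose frame i is matched to the TextGrid interval
+        # containing time i / pose_fps; e.g. at pose_fps = 30, frame 45 queries t = 1.5 s.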
+        time_offset = 0
+        if self.args.word_rep is not None:
+            logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #")
+            word_file = self.textgrid_file_path
+            if not os.path.exists(word_file):
+                logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #")
+                self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
+            word_save_path = f"{self.data_dir}{self.args.t_pre_encoder}/{id_pose}.npy"
+
+            tgrid = tg.TextGrid.fromFile(word_file)
+
+            for i in range(pose_each_file.shape[0]):
+                found_flag = False
+                current_time = i/self.args.pose_fps + time_offset
+                j_last = 0
+                for j, word in enumerate(tgrid[0]):
+                    word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
+                    if word_s <= current_time and current_time <= word_e:
+                        if word_n == " ":
+                            word_each_file.append(self.lang_model.PAD_token)
+                        else:
+                            word_each_file.append(self.lang_model.get_word_index(word_n))
+                        found_flag = True
+                        j_last = j
+                        break
+                    else:
+                        continue
+                if not found_flag:
+                    word_each_file.append(self.lang_model.UNK_token)
+            word_each_file = np.array(word_each_file)
+
+        if self.args.emo_rep is not None:
+            logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #")
+            # NOTE: id_pose is "tmp" in demo mode, so emo_rep is expected to be None here;
+            # the field parsing below follows the BEAT file naming (e.g. 1_wayne_0_1_1).
+            rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3])
+            score = 0  # default, also covers out-of-range start values
+            if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6:
+                if start >= 1 and start <= 64:
+                    score = 0
+                elif start >= 65 and start <= 72:
+                    score = 1
+                elif start >= 73 and start <= 80:
+                    score = 2
+                elif start >= 81 and start <= 86:
+                    score = 3
+                elif start >= 87 and start <= 94:
+                    score = 4
+                elif start >= 95 and start <= 102:
+                    score = 5
+                elif start >= 103 and start <= 110:
+                    score = 6
+                elif start >= 111 and start <= 118:
+                    score = 7
+                else:
+                    pass
+            else:
+                # you may denote as unknown in the future
+                score = 0
+            emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0)
+
+        if self.args.sem_rep is not None:
+            logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #")
+            sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt"
+            sem_all = pd.read_csv(sem_file,
+                                  sep='\t',
+                                  names=["name", "start_time", "end_time", "duration", "score", "keywords"])
+            # we adopt motion-level semantic score here.
+            for i in range(pose_each_file.shape[0]):
+                found_flag = False
+                for j, (start, end, score) in enumerate(zip(sem_all['start_time'], sem_all['end_time'], sem_all['score'])):
+                    current_time = i/self.args.pose_fps + time_offset
+                    if start <= current_time and current_time <= end:
+                        sem_each_file.append(score)
+                        found_flag = True
+                        break
+                    else:
+                        continue
+                if not found_flag:
+                    sem_each_file.append(0.)
+            sem_each_file = np.array(sem_each_file)
+
+        filtered_result = self._sample_from_clip(
+            dst_lmdb_env,
+            audio_each_file, pose_each_file, trans_each_file, trans_v_each_file, shape_each_file, facial_each_file, word_each_file,
+            vid_each_file, emo_each_file, sem_each_file,
+            disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
+        )
+        for type_key in filtered_result.keys():
+            n_filtered_out[type_key] += filtered_result[type_key]
+
+        #### ---------for_end------------ ####
+        with dst_lmdb_env.begin() as txn:
+            logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan"))
+            n_total_filtered = 0
+            for type_key, n_filtered in n_filtered_out.items():
+                logger.info("{}: {}".format(type_key, n_filtered))
+                n_total_filtered += n_filtered
+            logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format(
+                n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan"))
+        dst_lmdb_env.sync()
+        dst_lmdb_env.close()
+
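+    # Windowing arithmetic (illustrative numbers): with 2400 usable pose frames,
+    # cut_length = 128 and stride = 20, num_subdivision = floor((2400-128)/20)+1 = 114
+    # overlapping clips; for test, stride = cut_length, so a single full-length clip.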
+    def _sample_from_clip(
+        self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, trans_v_each_file, shape_each_file, facial_each_file, word_each_file,
+        vid_each_file, emo_each_file, sem_each_file,
+        disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
+    ):
+        """
+        for data cleaning, we ignore the data for the first and final n seconds
+        for test, we return all data
+        """
+        round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps  # assume 1500 frames / 15 fps = 100 s
+        if audio_each_file is not None:
+            # the "mfcc" check is placed first; it was unreachable after the broader
+            # != "wave16k" test (the computed value is the same either way)
+            if self.args.audio_rep == "mfcc":
+                round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps
+            elif self.args.audio_rep != "wave16k":
+                round_seconds_audio = len(audio_each_file) // self.args.audio_fps  # assume 1,600,000 / 16,000 = 100 s
+            else:
+                round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr
+            if facial_each_file is not None:
+                round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps
+                logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s")
+                round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
+                max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
+                if round_seconds_skeleton != max_round:
+                    logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s")
+            else:
+                logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s")
+                round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton)
+                max_round = max(round_seconds_audio, round_seconds_skeleton)
+                if round_seconds_skeleton != max_round:
+                    logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s")
+
+        clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds  # assume [10, 90]s
+        clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps  # [160,000, 90*160,000]
+        clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps  # [150, 90*15]
+
+        for ratio in self.args.multi_length_training:
+            if is_test:  # stride = length for test
+                cut_length = clip_e_f_pose - clip_s_f_pose
+                self.args.stride = cut_length
+                self.max_length = cut_length
+            else:
+                self.args.stride = int(ratio*self.ori_stride)
+                cut_length = int(self.ori_length*ratio)
+
+            num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1
+            logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}")
+            logger.info(f"{num_subdivision} clips are expected with stride {self.args.stride}")
+
+            if audio_each_file is not None:
+                audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps)
+                # for audio sr = 16000, fps = 15, pose_length = 34,
+                # audio short length = 36266.7 -> 36266; this error is fine.
+                logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}")
+
+            n_filtered_out = defaultdict(int)
+            sample_pose_list = []
+            sample_audio_list = []
+            sample_facial_list = []
+            sample_shape_list = []
+            sample_word_list = []
+            sample_emo_list = []
+            sample_sem_list = []
+            sample_vid_list = []
+            sample_trans_list = []
+            sample_trans_v_list = []
+
+            for i in range(num_subdivision):  # cut into around 2s chips (self.npose)
+                start_idx = clip_s_f_pose + i * self.args.stride
+                fin_idx = start_idx + cut_length
+                sample_pose = pose_each_file[start_idx:fin_idx]
+
+                sample_trans = trans_each_file[start_idx:fin_idx]
+                sample_trans_v = trans_v_each_file[start_idx:fin_idx]
+                sample_shape = shape_each_file[start_idx:fin_idx]
+                if self.args.audio_rep is not None:
+                    audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps)
+                    audio_end = audio_start + audio_short_length
+                    sample_audio = audio_each_file[audio_start:audio_end]
+                else:
+                    sample_audio = np.array([-1])
+                sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1])
+                sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1])
+                sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1])
+                sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1])
+                sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1])
+
+                if sample_pose is not None:  # was `sample_pose.any() != None`, which is always true
+                    # filtering motion skeleton data
+                    sample_pose, filtering_message = MotionPreprocessor(sample_pose).get()
+                    is_correct_motion = (sample_pose is not None)
+                    if is_correct_motion or disable_filtering:
+                        sample_pose_list.append(sample_pose)
+                        sample_audio_list.append(sample_audio)
+                        sample_facial_list.append(sample_facial)
+                        sample_shape_list.append(sample_shape)
+                        sample_word_list.append(sample_word)
+                        sample_vid_list.append(sample_vid)
+                        sample_emo_list.append(sample_emo)
+                        sample_sem_list.append(sample_sem)
+                        sample_trans_list.append(sample_trans)
+                        sample_trans_v_list.append(sample_trans_v)
+                    else:
+                        n_filtered_out[filtering_message] += 1
+
+            if len(sample_pose_list) > 0:
+                with dst_lmdb_env.begin(write=True) as txn:
+                    for pose, audio, facial, shape, word, vid, emo, sem, trans, trans_v in zip(
+                            sample_pose_list,
+                            sample_audio_list,
+                            sample_facial_list,
+                            sample_shape_list,
+                            sample_word_list,
+                            sample_vid_list,
+                            sample_emo_list,
+                            sample_sem_list,
+                            sample_trans_list,
+                            sample_trans_v_list):
+                        k = "{:005}".format(self.n_out_samples).encode("ascii")
+                        v = [pose, audio, facial, shape, word, emo, sem, vid, trans, trans_v]
+                        v = pickle.dumps(v, 5)
+                        txn.put(k, v)
+                        self.n_out_samples += 1
+        return n_filtered_out
+
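+    # Samples are keyed by zero-padded write order, "{:005}".format(idx) -> "00012" for
+    # idx = 12, so __getitem__ reads back the idx-th sample written by _sample_from_clip.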
+    def __getitem__(self, idx):
+        with self.lmdb_env.begin(write=False) as txn:
+            key = "{:005}".format(idx).encode("ascii")
+            sample = txn.get(key)
+            sample = pickle.loads(sample)
+            tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans, trans_v = sample
+            #vid = torch.from_numpy(vid).int()
+            emo = torch.from_numpy(emo).int()
+            sem = torch.from_numpy(sem).float()
+            in_audio = torch.from_numpy(in_audio).float()
+            in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int()
+            if self.loader_type == "test":
+                tar_pose = torch.from_numpy(tar_pose).float()
+                trans = torch.from_numpy(trans).float()
+                trans_v = torch.from_numpy(trans_v).float()
+                in_facial = torch.from_numpy(in_facial).float()
+                vid = torch.from_numpy(vid).float()
+                in_shape = torch.from_numpy(in_shape).float()
+            else:
+                in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
+                trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float()
+                trans_v = torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float()
+                vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float()
+                tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float()
+                in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float()
+            return {"pose": tar_pose, "audio": in_audio, "facial": in_facial, "beta": in_shape, "word": in_word, "id": vid, "emo": emo, "sem": sem, "trans": trans, "trans_v": trans_v}
+
+
+class MotionPreprocessor:
+    def __init__(self, skeletons):
+        self.skeletons = skeletons
+        #self.mean_pose = mean_pose
+        self.filtering_message = "PASS"
+
+    def get(self):
+        assert (self.skeletons is not None)
+
+        # filtering
+        if self.skeletons is not None:
+            if self.check_pose_diff():
+                self.skeletons = []
+                self.filtering_message = "pose"
+            # elif self.check_spine_angle():
+            #     self.skeletons = []
+            #     self.filtering_message = "spine angle"
+            # elif self.check_static_motion():
+            #     self.skeletons = []
+            #     self.filtering_message = "motion"
+
+        # if self.skeletons is not None:
+        #     self.skeletons = self.skeletons.tolist()
+        #     for i, frame in enumerate(self.skeletons):
+        #         assert not np.isnan(self.skeletons[i]).any()  # missing joints
+
+        return self.skeletons, self.filtering_message
+
+    def check_static_motion(self, verbose=True):
+        def get_variance(skeleton, joint_idx):
+            wrist_pos = skeleton[:, joint_idx]
+            variance = np.sum(np.var(wrist_pos, axis=0))
+            return variance
+
+        left_arm_var = get_variance(self.skeletons, 6)
+        right_arm_var = get_variance(self.skeletons, 9)
+
+        th = 0.0014  # exclude 13110
+        # th = 0.002  # exclude 16905
+        if left_arm_var < th and right_arm_var < th:
+            if verbose:
+                print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
+            return True
+        else:
+            if verbose:
+                print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
+            return False
+
+    def check_pose_diff(self, verbose=False):
+        # Filtering is disabled in this loader: the mean-pose distance test below is
+        # commented out, so every clip passes.
+        # diff = np.abs(self.skeletons - self.mean_pose)  # 186*1
+        # diff = np.mean(diff)
+        # th = 0.02  # exclude 3594
+        # if diff < th:
+        #     if verbose:
+        #         print("skip - check_pose_diff {:.5f}".format(diff))
+        #     return True
+        # else:
+        #     if verbose:
+        #         print("pass - check_pose_diff {:.5f}".format(diff))
+        return False
+
+    def check_spine_angle(self, verbose=True):
+        def angle_between(v1, v2):
+            v1_u = v1 / np.linalg.norm(v1)
+            v2_u = v2 / np.linalg.norm(v2)
+            return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
+
+        angles = []
+        for i in range(self.skeletons.shape[0]):
+            spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0]
+            angle = angle_between(spine_vec, [0, -1, 0])
+            angles.append(angle)
+
+        if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20:  # exclude 4495
+            # if np.rad2deg(max(angles)) > 20:  # exclude 8270
+            if verbose:
+                print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles)))
+            return True
+        else:
+            if verbose:
+                print("pass - check_spine_angle {:.5f}".format(max(angles)))
+            return False
dataloaders/beat_smplx2020.py
ADDED
|
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import pickle
import math
import shutil
import numpy as np
import lmdb as lmdb
import textgrid as tg
import pandas as pd
import torch
import glob
import json
from termcolor import colored
from loguru import logger
from collections import defaultdict
from torch.utils.data import Dataset
import torch.distributed as dist
import pyarrow
import librosa
import smplx

from .build_vocab import Vocab
from .utils.audio_features import Wav2Vec2Model
from .data_tools import joints_list
from .utils import rotation_conversions as rc
from .utils import other_tools

class CustomDataset(Dataset):
    def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
        self.args = args
        self.loader_type = loader_type

        self.rank = dist.get_rank()
        self.ori_stride = self.args.stride
        self.ori_length = self.args.pose_length
        self.alignment = [0, 0]  # for trinity

        # select trainable joints: build a 0/1 mask over the pose channels
        self.ori_joint_list = joints_list[self.args.ori_joints]
        self.tar_joint_list = joints_list[self.args.tar_joints]
        if 'smplx' in self.args.pose_rep:
            self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
            self.joints = len(list(self.ori_joint_list.keys()))
            for joint_name in self.tar_joint_list:
                self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
        else:
            self.joints = len(list(self.ori_joint_list.keys())) + 1
            self.joint_mask = np.zeros(self.joints * 3)
            for joint_name in self.tar_joint_list:
                if joint_name == "Hips":
                    self.joint_mask[3:6] = 1
                else:
                    self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1

        split_rule = pd.read_csv(args.data_path + "train_test_split.csv")
        self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
        if args.additional_data and loader_type == 'train':
            split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
            self.selected_file = pd.concat([self.selected_file, split_b])
        if self.selected_file.empty:
            logger.warning(f"{loader_type} is empty for speakers {self.args.training_speakers}, using the first 8 train files instead")
            self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
            self.selected_file = self.selected_file.iloc[0:8]
        self.data_dir = args.data_path

        if loader_type == "test":
            self.args.multi_length_training = [1.0]
        self.max_length = int(args.pose_length * self.args.multi_length_training[-1])
        self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr)
        if self.max_audio_pre_len > self.args.test_length * self.args.audio_sr:
            self.max_audio_pre_len = self.args.test_length * self.args.audio_sr

        if args.word_rep is not None:
            with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f:
                self.lang_model = pickle.load(f)

        preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache"
        # pose/audio/facial mean-std normalization is currently disabled:
        # if args.pose_norm:
        #     # careful for rotation vectors
        #     if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"):
        #         self.calculate_mean_pose()
        #     self.mean_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy")
        #     self.std_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_std.npy")
        # if args.audio_norm:
        #     if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/bvh_mean.npy"):
        #         self.calculate_mean_audio()
        #     self.mean_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_mean.npy")
        #     self.std_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_std.npy")
        # if args.facial_norm:
        #     if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"):
        #         self.calculate_mean_face()
        #     self.mean_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy")
        #     self.std_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy")
        if self.args.beat_align:
            if not os.path.exists(args.data_path + f"weights/mean_vel_{args.pose_rep}.npy"):
                self.calculate_mean_velocity(args.data_path + f"weights/mean_vel_{args.pose_rep}.npy")
            self.avg_vel = np.load(args.data_path + f"weights/mean_vel_{args.pose_rep}.npy")

        if build_cache and self.rank == 0:
            self.build_cache(preloaded_dir)
        self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False)
        with self.lmdb_env.begin() as txn:
            self.n_samples = txn.stat()["entries"]

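    # A minimal sketch (not part of the original file) of the fields this loader
    # reads from `args`; the names come from the code above, the values are
    # illustrative assumptions only:
    #
    #   from types import SimpleNamespace
    #   args = SimpleNamespace(
    #       data_path="./BEAT2/", data_path_1="./weights_root/", root_path="./",
    #       cache_path="cache/", pose_rep="smplxflame_30", ori_joints="beat_smplx_joints",
    #       tar_joints="beat_smplx_full", pose_fps=30, pose_length=64, stride=20,
    #       training_speakers=[2], additional_data=False, multi_length_training=[1.0],
    #       audio_rep="onset+amplitude", audio_sr=16000, audio_fps=16000, audio_norm=False,
    #       facial_rep="smplxflame_30", facial_norm=False, word_rep="textgrid",
    #       word_cache=False, t_pre_encoder="fasttext", emo_rep="emo", sem_rep="sem",
    #       id_rep="onehot", beat_align=False, test_length=64, new_cache=True,
    #       disable_filtering=False, clean_first_seconds=0, clean_final_seconds=0,
    #   )
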
    def calculate_mean_velocity(self, save_path):
        self.smplx = smplx.create(
            self.args.data_path_1 + "smplx_models/",
            model_type='smplx',
            gender='NEUTRAL_2020',
            use_face_contour=False,
            num_betas=300,
            num_expression_coeffs=100,
            ext='npz',
            use_pca=False,
        ).cuda().eval()
        dir_p = self.data_dir + self.args.pose_rep + "/"
        all_list = []
        from tqdm import tqdm
        for tar in tqdm(os.listdir(dir_p)):
            if tar.endswith(".npz"):
                m_data = np.load(dir_p + tar, allow_pickle=True)
                betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"]
                n, c = poses.shape[0], poses.shape[1]
                betas = betas.reshape(1, 300)
                betas = np.tile(betas, (n, 1))
                betas = torch.from_numpy(betas).cuda().float()
                poses = torch.from_numpy(poses.reshape(n, c)).cuda().float()
                exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float()
                trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float()
                max_length = 128  # run the SMPL-X layer in chunks to bound GPU memory
                s, r = n // max_length, n % max_length
                all_tensor = []
                for i in range(s):
                    with torch.no_grad():
                        joints = self.smplx(
                            betas=betas[i*max_length:(i+1)*max_length],
                            transl=trans[i*max_length:(i+1)*max_length],
                            expression=exps[i*max_length:(i+1)*max_length],
                            jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69],
                            global_orient=poses[i*max_length:(i+1)*max_length, :3],
                            body_pose=poses[i*max_length:(i+1)*max_length, 3:21*3+3],
                            left_hand_pose=poses[i*max_length:(i+1)*max_length, 25*3:40*3],
                            right_hand_pose=poses[i*max_length:(i+1)*max_length, 40*3:55*3],
                            return_verts=True,
                            return_joints=True,
                            leye_pose=poses[i*max_length:(i+1)*max_length, 69:72],
                            reye_pose=poses[i*max_length:(i+1)*max_length, 72:75],
                        )['joints'][:, :55, :].reshape(max_length, 55*3)
                    all_tensor.append(joints)
                if r != 0:
                    with torch.no_grad():
                        joints = self.smplx(
                            betas=betas[s*max_length:s*max_length+r],
                            transl=trans[s*max_length:s*max_length+r],
                            expression=exps[s*max_length:s*max_length+r],
                            jaw_pose=poses[s*max_length:s*max_length+r, 66:69],
                            global_orient=poses[s*max_length:s*max_length+r, :3],
                            body_pose=poses[s*max_length:s*max_length+r, 3:21*3+3],
                            left_hand_pose=poses[s*max_length:s*max_length+r, 25*3:40*3],
                            right_hand_pose=poses[s*max_length:s*max_length+r, 40*3:55*3],
                            return_verts=True,
                            return_joints=True,
                            leye_pose=poses[s*max_length:s*max_length+r, 69:72],
                            reye_pose=poses[s*max_length:s*max_length+r, 72:75],
                        )['joints'][:, :55, :].reshape(r, 55*3)
                    all_tensor.append(joints)
                joints = torch.cat(all_tensor, dim=0)
                joints = joints.permute(1, 0)
                dt = 1 / 30
                # first step: forward difference (t+1 - t) / dt
                init_vel = (joints[:, 1:2] - joints[:, :1]) / dt
                # middle steps: second-order central difference (t+1 - t-1) / 2dt
                middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt)
                # last step: backward difference (t - t-1) / dt
                final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
                vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1).permute(1, 0).reshape(n, 55, 3)
                vel_seq_np = vel_seq.cpu().numpy()
                vel_joints_np = np.linalg.norm(vel_seq_np, axis=2)  # n * 55
                all_list.append(vel_joints_np)
        avg_vel = np.mean(np.concatenate(all_list, axis=0), axis=0)  # 55
        np.save(save_path, avg_vel)

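    # Sketch (illustrative, not part of the original file) of the same
    # forward/central/backward difference scheme on a toy 1-D trajectory.
    # For x(t) = t^2 sampled at dt = 1, the central estimate at t = 1 is
    # (x[2] - x[0]) / (2*dt) = (4 - 0) / 2 = 2, matching x'(1) = 2 exactly
    # (central differences are exact for quadratics), while the one-sided
    # endpoint estimates carry O(dt) error:
    #
    #   x = np.array([0., 1., 4., 9.])        # t^2 at t = 0, 1, 2, 3
    #   v = np.empty_like(x)
    #   v[0] = (x[1] - x[0]) / 1.0            # forward diff  -> 1.0 (true 0.0)
    #   v[1:-1] = (x[2:] - x[:-2]) / 2.0      # central diff  -> 2.0, 4.0 (exact)
    #   v[-1] = (x[-1] - x[-2]) / 1.0         # backward diff -> 5.0 (true 6.0)
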
    def build_cache(self, preloaded_dir):
        logger.info(f"Audio bit rate: {self.args.audio_fps}")
        logger.info("Reading data '{}'...".format(self.data_dir))
        logger.info("Creating the dataset cache...")
        if self.args.new_cache:
            if os.path.exists(preloaded_dir):
                shutil.rmtree(preloaded_dir)
        if os.path.exists(preloaded_dir):
            logger.info("Found the cache {}".format(preloaded_dir))
        elif self.loader_type == "test":
            self.cache_generation(
                preloaded_dir, True,
                0, 0,
                is_test=True)
        else:
            self.cache_generation(
                preloaded_dir, self.args.disable_filtering,
                self.args.clean_first_seconds, self.args.clean_final_seconds,
                is_test=False)

    def __len__(self):
        return self.n_samples

    def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
        # if "wav2vec2" in self.args.audio_rep:
        #     self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h")
        #     self.wav2vec_model.feature_extractor._freeze_parameters()
        #     self.wav2vec_model = self.wav2vec_model.cuda()
        #     self.wav2vec_model.eval()

        self.n_out_samples = 0
        # create the lmdb database that holds the windowed samples
        if not os.path.exists(out_lmdb_dir):
            os.makedirs(out_lmdb_dir)
        dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size=int(1024 ** 3 * 50))  # 50 GB
        n_filtered_out = defaultdict(int)

        for index, file_name in self.selected_file.iterrows():
            f_name = file_name["id"]
            ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
            pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext
            pose_each_file = []
            trans_each_file = []
            shape_each_file = []
            audio_each_file = []
            facial_each_file = []
            word_each_file = []
            emo_each_file = []
            sem_each_file = []
            vid_each_file = []
            id_pose = f_name  # e.g. 1_wayne_0_1_1

            logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue"))
            if "smplx" in self.args.pose_rep:
                pose_data = np.load(pose_file, allow_pickle=True)
                assert 30 % self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30'
                stride = int(30 / self.args.pose_fps)
                pose_each_file = pose_data["poses"][::stride] * self.joint_mask
                trans_each_file = pose_data["trans"][::stride]
                shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0)
                if self.args.facial_rep is not None:
                    logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
                    facial_each_file = pose_data["expressions"][::stride]
                    if self.args.facial_norm:
                        facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial

            else:
                assert 120 % self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120'
                stride = int(120 / self.args.pose_fps)
                with open(pose_file, "r") as pose_data:
                    for j, line in enumerate(pose_data.readlines()):
                        if j < 431: continue  # skip the bvh header
                        if j % stride != 0: continue
                        data = np.fromstring(line, dtype=float, sep=" ")
                        rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints, 3), "XYZ")
                        rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints * 3)
                        rot_data = rot_data.numpy() * self.joint_mask

                        pose_each_file.append(rot_data)
                        trans_each_file.append(data[:3])

                pose_each_file = np.array(pose_each_file)
                trans_each_file = np.array(trans_each_file)
                shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0)
                if self.args.facial_rep is not None:
                    logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
                    facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json")
                    assert 60 % self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 60'
                    stride = int(60 / self.args.pose_fps)
                    if not os.path.exists(facial_file):
                        logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #")
                        self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
                        continue
                    with open(facial_file, 'r') as facial_data_file:
                        facial_data = json.load(facial_data_file)
                        for j, frame_data in enumerate(facial_data['frames']):
                            if j % stride != 0: continue
                            facial_each_file.append(frame_data['weights'])
                    facial_each_file = np.array(facial_each_file)
                    if self.args.facial_norm:
                        facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial

            if self.args.id_rep is not None:
                vid_each_file = np.repeat(np.array(int(f_name.split("_")[0]) - 1).reshape(1, 1), pose_each_file.shape[0], axis=0)

            if self.args.audio_rep is not None:
                logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #")
                audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav")
                if not os.path.exists(audio_file):
                    logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #")
                    self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
                    continue
                audio_each_file, sr = librosa.load(audio_file)
                audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr)
                if self.args.audio_rep == "onset+amplitude":
                    from numpy.lib import stride_tricks
                    frame_length = 1024
                    # amplitude envelope: max |x| over a sliding window of frame_length samples
                    shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length)
                    strides = (audio_each_file.strides[-1], audio_each_file.strides[-1])
                    rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides)
                    amplitude_envelope = np.max(np.abs(rolling_view), axis=1)
                    # pad the last frame_length-1 samples
                    amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length - 1), mode='constant', constant_values=amplitude_envelope[-1])
                    audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames')
                    onset_array = np.zeros(len(audio_each_file), dtype=float)
                    onset_array[audio_onset_f] = 1.0
                    audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1)
                elif self.args.audio_rep == "mfcc":
                    audio_each_file = librosa.feature.mfcc(y=audio_each_file, sr=self.args.audio_sr, n_mfcc=13, hop_length=int(self.args.audio_sr / self.args.audio_fps))

                if self.args.audio_norm and self.args.audio_rep == "wave16k":
                    audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio

            time_offset = 0
            if self.args.word_rep is not None:
                logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #")
                word_file = f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid"
                if not os.path.exists(word_file):
                    logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #")
                    self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index)
                    continue
                tgrid = tg.TextGrid.fromFile(word_file)
                if self.args.t_pre_encoder == "bert":
                    from transformers import AutoTokenizer, BertModel
                    tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True)
                    model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval()
                    list_word = []
                    all_hidden = []
                    max_len = 400
                    last = 0
                    word_token_mapping = []
                    first = True
                    for i, word in enumerate(tgrid[0]):
                        last = i
                        if (i % max_len != 0) or (i == 0):
                            if word.mark == "":
                                list_word.append(".")
                            else:
                                list_word.append(word.mark)
                        else:
                            max_counter = max_len
                            str_word = ' '.join(map(str, list_word))
                            if first:
                                global_len = 0
                            end = -1
                            offset_word = []
                            for k, wordvalue in enumerate(list_word):
                                start = end + 1
                                end = start + len(wordvalue)
                                offset_word.append((start, end))
                            token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping']
                            for start, end in offset_word:
                                sub_mapping = []
                                for i, (start_t, end_t) in enumerate(token_scan[1:-1]):
                                    if int(start) <= int(start_t) and int(end_t) <= int(end):
                                        sub_mapping.append(i + global_len)
                                word_token_mapping.append(sub_mapping)
                            global_len = word_token_mapping[-1][-1] + 1
                            list_word = []
                            if word.mark == "":
                                list_word.append(".")
                            else:
                                list_word.append(word.mark)

                            with torch.no_grad():
                                inputs = tokenizer(str_word, return_tensors="pt")
                                outputs = model(**inputs)
                                last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :]
                            all_hidden.append(last_hidden_states)

                    if list_word == []:
                        pass
                    else:
                        if first:
                            global_len = 0
                        str_word = ' '.join(map(str, list_word))
                        end = -1
                        offset_word = []
                        for k, wordvalue in enumerate(list_word):
                            start = end + 1
                            end = start + len(wordvalue)
                            offset_word.append((start, end))
                        token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping']
                        for start, end in offset_word:
                            sub_mapping = []
                            for i, (start_t, end_t) in enumerate(token_scan[1:-1]):
                                if int(start) <= int(start_t) and int(end_t) <= int(end):
                                    sub_mapping.append(i + global_len)
                            word_token_mapping.append(sub_mapping)
                        with torch.no_grad():
                            inputs = tokenizer(str_word, return_tensors="pt")
                            outputs = model(**inputs)
                            last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :]
                        all_hidden.append(last_hidden_states)
                    last_hidden_states = np.concatenate(all_hidden, axis=0)

                for i in range(pose_each_file.shape[0]):
                    found_flag = False
                    current_time = i / self.args.pose_fps + time_offset
                    j_last = 0
                    for j, word in enumerate(tgrid[0]):
                        word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
                        if word_s <= current_time and current_time <= word_e:
                            if self.args.word_cache and self.args.t_pre_encoder == 'bert':
                                mapping_index = word_token_mapping[j]
                                s_t = np.linspace(word_s, word_e, len(mapping_index) + 1)
                                for tt, t_sep in enumerate(s_t[1:]):
                                    if current_time <= t_sep:
                                        word_each_file.append(last_hidden_states[mapping_index[tt]])
                                        break
                            else:
                                if word_n == " ":
                                    word_each_file.append(self.lang_model.PAD_token)
                                else:
                                    word_each_file.append(self.lang_model.get_word_index(word_n))
                            found_flag = True
                            j_last = j
                            break
                        else: continue
                    if not found_flag:
                        if self.args.word_cache and self.args.t_pre_encoder == 'bert':
                            word_each_file.append(last_hidden_states[j_last])
                        else:
                            word_each_file.append(self.lang_model.UNK_token)
                word_each_file = np.array(word_each_file)

            if self.args.emo_rep is not None:
                logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #")
                rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3])
                if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6:
                    if start >= 1 and start <= 64:
                        score = 0
                    elif start >= 65 and start <= 72:
                        score = 1
                    elif start >= 73 and start <= 80:
                        score = 2
                    elif start >= 81 and start <= 86:
                        score = 3
                    elif start >= 87 and start <= 94:
                        score = 4
                    elif start >= 95 and start <= 102:
                        score = 5
                    elif start >= 103 and start <= 110:
                        score = 6
                    elif start >= 111 and start <= 118:
                        score = 7
                    else: pass
                else:
                    # recordings outside these types may be denoted as unknown in the future
                    score = 0
                emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0)

            if self.args.sem_rep is not None:
                logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #")
                sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt"
                sem_all = pd.read_csv(sem_file,
                                      sep='\t',
                                      names=["name", "start_time", "end_time", "duration", "score", "keywords"])
                # we adopt a motion-level semantic score here
                for i in range(pose_each_file.shape[0]):
                    found_flag = False
                    for j, (start, end, score) in enumerate(zip(sem_all['start_time'], sem_all['end_time'], sem_all['score'])):
                        current_time = i / self.args.pose_fps + time_offset
                        if start <= current_time and current_time <= end:
                            sem_each_file.append(score)
                            found_flag = True
                            break
                        else: continue
                    if not found_flag: sem_each_file.append(0.)
                sem_each_file = np.array(sem_each_file)

            filtered_result = self._sample_from_clip(
                dst_lmdb_env,
                audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file,
                vid_each_file, emo_each_file, sem_each_file,
                disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
            )
            for filter_type in filtered_result.keys():
                n_filtered_out[filter_type] += filtered_result[filter_type]

        with dst_lmdb_env.begin() as txn:
            logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan"))
            n_total_filtered = 0
            for filter_type, n_filtered in n_filtered_out.items():
                logger.info("{}: {}".format(filter_type, n_filtered))
                n_total_filtered += n_filtered
            logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format(
                n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan"))
        dst_lmdb_env.sync()
        dst_lmdb_env.close()

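    # Sketch (illustrative, not part of the original file) of the
    # "onset+amplitude" feature on a synthetic signal, mirroring the branch
    # above: a sliding-window max gives the amplitude envelope, librosa's
    # onset detector gives a binary onset track, and the two are stacked
    # into a (n_samples, 2) array. As in the code above, frame indices from
    # onset_detect are used directly as sample indices.
    #
    #   import numpy as np, librosa
    #   sr = 16000
    #   y = librosa.chirp(fmin=200, fmax=400, sr=sr, duration=1.0)
    #   win = 1024
    #   env = np.max(np.abs(np.lib.stride_tricks.sliding_window_view(y, win)), axis=1)
    #   env = np.pad(env, (0, win - 1), mode='edge')
    #   onsets = np.zeros(len(y))
    #   onsets[librosa.onset.onset_detect(y=y, sr=sr, units='frames')] = 1.0
    #   feats = np.stack([env, onsets], axis=1)   # shape (len(y), 2)
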
    def _sample_from_clip(
        self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file,
        vid_each_file, emo_each_file, sem_each_file,
        disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
    ):
        """
        for data cleaning, we ignore the data for the first and final n seconds
        for test, we return all data
        """
        # alignment trimming via self.alignment is currently disabled
        round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps  # e.g. 1500 frames / 15 fps = 100 s
        if len(audio_each_file) != 0:
            round_seconds_audio = len(audio_each_file) // self.args.audio_fps  # e.g. 1,600,000 / 16,000 = 100 s
            if len(facial_each_file) != 0:
                round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps
                logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s")
                round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
                max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
                if round_seconds_skeleton != max_round:
                    logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round - round_seconds_skeleton}s")
            else:
                logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s")
                round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton)
                max_round = max(round_seconds_audio, round_seconds_skeleton)
                if round_seconds_skeleton != max_round:
                    logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round - round_seconds_skeleton}s")

        clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds  # e.g. [10, 90] s
        clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps  # e.g. [160,000, 90*16,000]
        clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps  # e.g. [150, 90*15]

        for ratio in self.args.multi_length_training:
            if is_test:  # stride = length for test
                cut_length = clip_e_f_pose - clip_s_f_pose
                self.args.stride = cut_length
                self.max_length = cut_length
            else:
                self.args.stride = int(ratio * self.ori_stride)
                cut_length = int(self.ori_length * ratio)

            num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1
            logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}")
            logger.info(f"{num_subdivision} clips are expected with stride {self.args.stride}")

            if len(audio_each_file) != 0:
                audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps)
                # for audio sr = 16000, fps = 15, pose_length = 34,
                # audio_short_length = 36266.7 -> 36266; this rounding error is fine.
                logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}")

            n_filtered_out = defaultdict(int)
            sample_pose_list = []
            sample_audio_list = []
            sample_facial_list = []
            sample_shape_list = []
            sample_word_list = []
            sample_emo_list = []
            sample_sem_list = []
            sample_vid_list = []
            sample_trans_list = []

            for i in range(num_subdivision):  # cut into clips of around 2 s
                start_idx = clip_s_f_pose + i * self.args.stride
                fin_idx = start_idx + cut_length
                sample_pose = pose_each_file[start_idx:fin_idx]
                sample_trans = trans_each_file[start_idx:fin_idx]
                sample_shape = shape_each_file[start_idx:fin_idx]
                if self.args.audio_rep is not None:
                    audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps)
                    audio_end = audio_start + audio_short_length
                    sample_audio = audio_each_file[audio_start:audio_end]
                else:
                    sample_audio = np.array([-1])
                sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1])
                sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1])
                sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1])
                sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1])
                sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1])

                if sample_pose is not None:
                    # filtering motion skeleton data
                    sample_pose, filtering_message = MotionPreprocessor(sample_pose).get()
                    is_correct_motion = len(sample_pose) != 0
                    if is_correct_motion or disable_filtering:
                        sample_pose_list.append(sample_pose)
                        sample_audio_list.append(sample_audio)
                        sample_facial_list.append(sample_facial)
                        sample_shape_list.append(sample_shape)
                        sample_word_list.append(sample_word)
                        sample_vid_list.append(sample_vid)
                        sample_emo_list.append(sample_emo)
                        sample_sem_list.append(sample_sem)
                        sample_trans_list.append(sample_trans)
                    else:
                        n_filtered_out[filtering_message] += 1

            if len(sample_pose_list) > 0:
                with dst_lmdb_env.begin(write=True) as txn:
                    for pose, audio, facial, shape, word, vid, emo, sem, trans in zip(
                        sample_pose_list,
                        sample_audio_list,
                        sample_facial_list,
                        sample_shape_list,
                        sample_word_list,
                        sample_vid_list,
                        sample_emo_list,
                        sample_sem_list,
                        sample_trans_list,
                    ):
                        k = "{:005}".format(self.n_out_samples).encode("ascii")
                        v = [pose, audio, facial, shape, word, emo, sem, vid, trans]
                        v = pyarrow.serialize(v).to_buffer()  # note: pyarrow.serialize() only exists in pyarrow < 2.0
                        txn.put(k, v)
                        self.n_out_samples += 1
        return n_filtered_out

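    # Worked example (illustrative, not part of the original file) of the
    # window arithmetic above: with clip_s_f_pose = 150, clip_e_f_pose = 1350,
    # cut_length = 34 and stride = 10,
    #   num_subdivision = floor((1350 - 150 - 34) / 10) + 1 = floor(116.6) + 1 = 117
    # clips, the i-th starting at pose frame 150 + 10*i, while its audio window
    # starts at clip_s_f_audio + floor(i * 10 * audio_fps / pose_fps).
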
    def __getitem__(self, idx):
        with self.lmdb_env.begin(write=False) as txn:
            key = "{:005}".format(idx).encode("ascii")
            sample = txn.get(key)
            sample = pyarrow.deserialize(sample)
            tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample
            emo = torch.from_numpy(emo).int()
            sem = torch.from_numpy(sem).float()
            in_audio = torch.from_numpy(in_audio).float()
            in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int()
            if self.loader_type == "test":
                tar_pose = torch.from_numpy(tar_pose).float()
                trans = torch.from_numpy(trans).float()
                in_facial = torch.from_numpy(in_facial).float()
                vid = torch.from_numpy(vid).float()
                in_shape = torch.from_numpy(in_shape).float()
            else:
                in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
                trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float()
                vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float()
                tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float()
                in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float()
            return {"pose": tar_pose, "audio": in_audio, "facial": in_facial, "beta": in_shape, "word": in_word, "id": vid, "emo": emo, "sem": sem, "trans": trans}

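# Usage sketch (illustrative, not part of the original file): with a populated
# config namespace like the one sketched after __init__ above and a built
# cache, batches come out as dicts of tensors keyed exactly as in __getitem__.
#
#   import torch.distributed as dist
#   from torch.utils.data import DataLoader
#   # the constructor calls dist.get_rank(), so a (single-process) group is needed:
#   dist.init_process_group("gloo", init_method="tcp://localhost:29500", rank=0, world_size=1)
#   train_set = CustomDataset(args, "train")
#   loader = DataLoader(train_set, batch_size=32, shuffle=True, drop_last=True)
#   batch = next(iter(loader))
#   print(batch["pose"].shape, batch["audio"].shape)
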
class MotionPreprocessor:
    def __init__(self, skeletons):
        self.skeletons = skeletons
        # self.mean_pose = mean_pose
        self.filtering_message = "PASS"

    def get(self):
        assert (self.skeletons is not None)

        # filtering
        if len(self.skeletons) != 0:
            if self.check_pose_diff():
                self.skeletons = []
                self.filtering_message = "pose"
            # elif self.check_spine_angle():
            #     self.skeletons = []
            #     self.filtering_message = "spine angle"
            # elif self.check_static_motion():
            #     self.skeletons = []
            #     self.filtering_message = "motion"

        return self.skeletons, self.filtering_message

    def check_static_motion(self, verbose=True):
        def get_variance(skeleton, joint_idx):
            wrist_pos = skeleton[:, joint_idx]
            variance = np.sum(np.var(wrist_pos, axis=0))
            return variance

        left_arm_var = get_variance(self.skeletons, 6)
        right_arm_var = get_variance(self.skeletons, 9)

        th = 0.0014  # exclude 13110
        # th = 0.002  # exclude 16905
        if left_arm_var < th and right_arm_var < th:
            if verbose:
                print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
            return True
        else:
            if verbose:
                print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var))
            return False

    def check_pose_diff(self, verbose=False):
        # the mean-pose distance check is currently disabled, so this method
        # always returns False and no clip is filtered on this criterion:
        # diff = np.abs(self.skeletons - self.mean_pose)  # 186*1
        # diff = np.mean(diff)
        # th = 0.02  # exclude 3594
        # if diff < th:
        #     if verbose:
        #         print("skip - check_pose_diff {:.5f}".format(diff))
        #     return True
        # else:
        #     if verbose:
        #         print("pass - check_pose_diff {:.5f}".format(diff))
        return False

    def check_spine_angle(self, verbose=True):
        def angle_between(v1, v2):
            v1_u = v1 / np.linalg.norm(v1)
            v2_u = v2 / np.linalg.norm(v2)
            return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))

        angles = []
        for i in range(self.skeletons.shape[0]):
            spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0]
            angle = angle_between(spine_vec, [0, -1, 0])
            angles.append(angle)

        if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20:  # exclude 4495
            # if np.rad2deg(max(angles)) > 20:  # exclude 8270
            if verbose:
                print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles)))
            return True
        else:
            if verbose:
                print("pass - check_spine_angle {:.5f}".format(max(angles)))
            return False
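
The MotionPreprocessor filter can be exercised on its own. A minimal sketch (illustrative; it assumes only numpy and the class above): a random clip passes because check_pose_diff is disabled, while an all-zero clip trips the (currently unused) static-motion check.

    import numpy as np
    clip = np.random.randn(34, 55, 3)  # 34 frames, 55 joints, xyz
    kept, message = MotionPreprocessor(clip).get()
    print(message)  # "PASS"; `kept` is the clip unchanged
    static = np.zeros((34, 55, 3))
    print(MotionPreprocessor(static).check_static_motion())  # True: variance 0 < 0.0014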
dataloaders/build_vocab.py ADDED
@@ -0,0 +1,199 @@
import numpy as np
import glob
import os
import pickle
import lmdb
#import pyarrow
import fasttext
from loguru import logger
from scipy import linalg


class Vocab:
    PAD_token = 0
    SOS_token = 1
    EOS_token = 2
    UNK_token = 3

    def __init__(self, name, insert_default_tokens=True):
        self.name = name
        self.trimmed = False
        self.word_embedding_weights = None
        self.reset_dictionary(insert_default_tokens)

    def reset_dictionary(self, insert_default_tokens=True):
        self.word2index = {}
        self.word2count = {}
        if insert_default_tokens:
            self.index2word = {self.PAD_token: "<PAD>", self.SOS_token: "<SOS>",
                               self.EOS_token: "<EOS>", self.UNK_token: "<UNK>"}
        else:
            self.index2word = {self.UNK_token: "<UNK>"}
        self.n_words = len(self.index2word)  # count default tokens

    def index_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def add_vocab(self, other_vocab):
        for word, _ in other_vocab.word2count.items():
            self.index_word(word)

    # remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('    word trimming, kept %s / %s = %.4f' % (
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # reinitialize the dictionary
        self.reset_dictionary()
        for word in keep_words:
            self.index_word(word)

    def get_word_index(self, word):
        if word in self.word2index:
            return self.word2index[word]
        else:
            return self.UNK_token

    def load_word_vectors(self, pretrained_path, embedding_dim=300):
        print("  loading word vectors from '{}'...".format(pretrained_path))

        # initialize embeddings to random values for special words
        init_sd = 1 / np.sqrt(embedding_dim)
        weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim])
        weights = weights.astype(np.float32)

        # read word vectors
        word_model = fasttext.load_model(pretrained_path)
        for word, id in self.word2index.items():
            vec = word_model.get_word_vector(word)
            weights[id] = vec
        self.word_embedding_weights = weights

    def __get_embedding_weight(self, pretrained_path, embedding_dim=300):
        """ function modified from http://ronny.rest/blog/post_2017_08_04_glove/ """
        print("Loading word embedding '{}'...".format(pretrained_path))
        cache_path = pretrained_path
        weights = None

        # use the cached file if it exists
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                print('  using cached result from {}'.format(cache_path))
                weights = pickle.load(f)
                if weights.shape != (self.n_words, embedding_dim):
                    logger.warning('  failed to load word embedding weights. reinitializing...')
                    weights = None

        if weights is None:
            # initialize embeddings to random values for special and OOV words
            init_sd = 1 / np.sqrt(embedding_dim)
            weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim])
            weights = weights.astype(np.float32)

            with open(pretrained_path, encoding="utf-8", mode="r") as textFile:
                num_embedded_words = 0
                for line_raw in textFile:
                    # extract the word and its embedding vector
                    line = line_raw.split()
                    try:
                        word, vector = (line[0], np.array(line[1:], dtype=np.float32))
                        # if it is in our vocab, then update the corresponding weights
                        id = self.word2index.get(word, None)
                        if id is not None:
                            weights[id] = vector
                            num_embedded_words += 1
                    except ValueError:
                        print('  parsing error at {}...'.format(line_raw[:50]))
                        continue
            print('  {} / {} word vectors are found in the embedding'.format(num_embedded_words, len(self.word2index)))

            with open(cache_path, 'wb') as f:
                pickle.dump(weights, f)
        return weights


def build_vocab(name, data_path, cache_path, word_vec_path=None, feat_dim=None):
    print('  building a language model...')
    #if not os.path.exists(cache_path):
    lang_model = Vocab(name)
    print('    indexing words from {}'.format(data_path))
    index_words_from_textgrid(lang_model, data_path)

    if word_vec_path is not None:
        lang_model.load_word_vectors(word_vec_path, feat_dim)
    else:
        print('    loaded from {}'.format(cache_path))
        with open(cache_path, 'rb') as f:
            lang_model = pickle.load(f)
        if word_vec_path is None:
            lang_model.word_embedding_weights = None
        elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words:
            logger.warning('    failed to load word embedding weights. check this')
            assert False

    with open(cache_path, 'wb') as f:
        pickle.dump(lang_model, f)

    return lang_model


def index_words(lang_model, data_path):
    # index words from a plain text file
    with open(data_path, "r") as f:
        for line in f.readlines():
            line = line.replace(",", " ")
            line = line.replace(".", " ")
            line = line.replace("?", " ")
            line = line.replace("!", " ")
            for word in line.split():
                lang_model.index_word(word)
    print('    indexed %d words' % lang_model.n_words)


def index_words_from_textgrid(lang_model, data_path):
    import textgrid as tg
    from tqdm import tqdm
    texts = os.listdir(data_path + "/textgrid/")
    for textfile in tqdm(texts):
        tgrid = tg.TextGrid.fromFile(data_path + "/textgrid/" + textfile)
        for word in tgrid[0]:
            word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
            word_n = word_n.replace(",", " ")
            word_n = word_n.replace(".", " ")
            word_n = word_n.replace("?", " ")
            word_n = word_n.replace("!", " ")
            lang_model.index_word(word_n)
    print('    indexed %d words' % lang_model.n_words)
    print(lang_model.word2index, lang_model.word2count)


if __name__ == "__main__":
    # 11195 words for all speakers, 5793 for 4 speakers
    # build_vocab("beat_english_15_141", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/vocab.pkl", "/home/ma-user/work/datasets/cc.en.300.bin", 300)
    build_vocab("beat_chinese_v1.0.0", "/data/datasets/beat_chinese_v1.0.0/", "/data/datasets/beat_chinese_v1.0.0/weights/vocab.pkl", "/home/ma-user/work/cc.zh.300.bin", 300)
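
A quick round-trip sketch for Vocab (illustrative; the words are made up): frequent words keep their index after trim, rare ones fall back to <UNK>.

    v = Vocab("demo")
    for w in ["hello", "world", "hello"]:
        v.index_word(w)
    print(v.get_word_index("hello"))  # 4 (first index after the four special tokens)
    v.trim(min_count=2)               # drops "world", which was seen only once
    print(v.get_word_index("world"))  # 3, i.e. Vocab.UNK_token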
dataloaders/data_tools.py ADDED
@@ -0,0 +1,1756 @@
import numpy as np
import pandas as pd  # assumed missing from the upload: pd is used by FIDCalculator below unless the pymo star imports re-export it
import glob
import os
import pickle
import lmdb
#import pyarrow
import fasttext
from loguru import logger
from scipy import linalg
from .pymo.parsers import BVHParser
from .pymo.viz_tools import *
from .pymo.preprocessing import *


# pose version fpsxx_trinity/japanese_joints(_xxx)
# Each table maps joint_name -> [n_channels, end_column] (full skeleton layouts)
# or joint_name -> n_channels (selection subsets used for retargeting).
joints_list = {
    "trinity_joints": {
        'Hips': [6,6], 'Spine': [3,9], 'Spine1': [3,12], 'Spine2': [3,15], 'Spine3': [3,18],
        'Neck': [3,21], 'Neck1': [3,24], 'Head': [3,27],
        'RShoulder': [3,30], 'RArm': [3,33], 'RArm1': [3,36], 'RHand': [3,39],
        'RHandT1': [3,42], 'RHandT2': [3,45], 'RHandT3': [3,48],
        'RHandI1': [3,51], 'RHandI2': [3,54], 'RHandI3': [3,57],
        'RHandM1': [3,60], 'RHandM2': [3,63], 'RHandM3': [3,66],
        'RHandR1': [3,69], 'RHandR2': [3,72], 'RHandR3': [3,75],
        'RHandP1': [3,78], 'RHandP2': [3,81], 'RHandP3': [3,84],
        'LShoulder': [3,87], 'LArm': [3,90], 'LArm1': [3,93], 'LHand': [3,96],
        'LHandT1': [3,99], 'LHandT2': [3,102], 'LHandT3': [3,105],
        'LHandI1': [3,108], 'LHandI2': [3,111], 'LHandI3': [3,114],
        'LHandM1': [3,117], 'LHandM2': [3,120], 'LHandM3': [3,123],
        'LHandR1': [3,126], 'LHandR2': [3,129], 'LHandR3': [3,132],
        'LHandP1': [3,135], 'LHandP2': [3,138], 'LHandP3': [3,141],
        'RUpLeg': [3,144], 'RLeg': [3,147], 'RFoot': [3,150], 'RFootF': [3,153], 'RToeBase': [3,156],
        'LUpLeg': [3,159], 'LLeg': [3,162], 'LFoot': [3,165], 'LFootF': [3,168], 'LToeBase': [3,171],
    },
    "trinity_joints_123": {
        'Spine': 3, 'Neck': 3, 'Neck1': 3,
        'RShoulder': 3, 'RArm': 3, 'RArm1': 3, 'RHand': 3,
        'RHandT1': 3, 'RHandT2': 3, 'RHandT3': 3, 'RHandI1': 3, 'RHandI2': 3, 'RHandI3': 3,
        'RHandM1': 3, 'RHandM2': 3, 'RHandM3': 3, 'RHandR1': 3, 'RHandR2': 3, 'RHandR3': 3,
        'RHandP1': 3, 'RHandP2': 3, 'RHandP3': 3,
        'LShoulder': 3, 'LArm': 3, 'LArm1': 3, 'LHand': 3,
        'LHandT1': 3, 'LHandT2': 3, 'LHandT3': 3, 'LHandI1': 3, 'LHandI2': 3, 'LHandI3': 3,
        'LHandM1': 3, 'LHandM2': 3, 'LHandM3': 3, 'LHandR1': 3, 'LHandR2': 3, 'LHandR3': 3,
        'LHandP1': 3, 'LHandP2': 3, 'LHandP3': 3,
    },
    "trinity_joints_168": {
        'Hips': 3, 'Spine': 3, 'Spine1': 3, 'Spine2': 3, 'Spine3': 3,
        'Neck': 3, 'Neck1': 3, 'Head': 3,
        'RShoulder': 3, 'RArm': 3, 'RArm1': 3, 'RHand': 3,
        'RHandT1': 3, 'RHandT2': 3, 'RHandT3': 3, 'RHandI1': 3, 'RHandI2': 3, 'RHandI3': 3,
        'RHandM1': 3, 'RHandM2': 3, 'RHandM3': 3, 'RHandR1': 3, 'RHandR2': 3, 'RHandR3': 3,
        'RHandP1': 3, 'RHandP2': 3, 'RHandP3': 3,
        'LShoulder': 3, 'LArm': 3, 'LArm1': 3, 'LHand': 3,
        'LHandT1': 3, 'LHandT2': 3, 'LHandT3': 3, 'LHandI1': 3, 'LHandI2': 3, 'LHandI3': 3,
        'LHandM1': 3, 'LHandM2': 3, 'LHandM3': 3, 'LHandR1': 3, 'LHandR2': 3, 'LHandR3': 3,
        'LHandP1': 3, 'LHandP2': 3, 'LHandP3': 3,
        'RUpLeg': 3, 'RLeg': 3, 'RFoot': 3, 'RFootF': 3, 'RToeBase': 3,
        'LUpLeg': 3, 'LLeg': 3, 'LFoot': 3, 'LFootF': 3, 'LToeBase': 3,
    },
    "trinity_joints_138": {
        "Hips": 3, 'Spine': 3, 'Spine1': 3, 'Spine2': 3, 'Spine3': 3,
        'Neck': 3, 'Neck1': 3, 'Head': 3,
        'RShoulder': 3, 'RArm': 3, 'RArm1': 3, 'RHand': 3,
        'RHandT1': 3, 'RHandT2': 3, 'RHandT3': 3, 'RHandI1': 3, 'RHandI2': 3, 'RHandI3': 3,
        'RHandM1': 3, 'RHandM2': 3, 'RHandM3': 3, 'RHandR1': 3, 'RHandR2': 3, 'RHandR3': 3,
        'RHandP1': 3, 'RHandP2': 3, 'RHandP3': 3,
        'LShoulder': 3, 'LArm': 3, 'LArm1': 3, 'LHand': 3,
        'LHandT1': 3, 'LHandT2': 3, 'LHandT3': 3, 'LHandI1': 3, 'LHandI2': 3, 'LHandI3': 3,
        'LHandM1': 3, 'LHandM2': 3, 'LHandM3': 3, 'LHandR1': 3, 'LHandR2': 3, 'LHandR3': 3,
        'LHandP1': 3, 'LHandP2': 3, 'LHandP3': 3,
    },
    "beat_smplx_joints": {
        'pelvis': [3,3], 'left_hip': [3,6], 'right_hip': [3,9], 'spine1': [3,12],
        'left_knee': [3,15], 'right_knee': [3,18], 'spine2': [3,21],
        'left_ankle': [3,24], 'right_ankle': [3,27], 'spine3': [3,30],
        'left_foot': [3,33], 'right_foot': [3,36], 'neck': [3,39],
        'left_collar': [3,42], 'right_collar': [3,45], 'head': [3,48],
        'left_shoulder': [3,51], 'right_shoulder': [3,54],
        'left_elbow': [3,57], 'right_elbow': [3,60],
        'left_wrist': [3,63], 'right_wrist': [3,66],
        'jaw': [3,69], 'left_eye_smplhf': [3,72], 'right_eye_smplhf': [3,75],
        'left_index1': [3,78], 'left_index2': [3,81], 'left_index3': [3,84],
        'left_middle1': [3,87], 'left_middle2': [3,90], 'left_middle3': [3,93],
        'left_pinky1': [3,96], 'left_pinky2': [3,99], 'left_pinky3': [3,102],
        'left_ring1': [3,105], 'left_ring2': [3,108], 'left_ring3': [3,111],
        'left_thumb1': [3,114], 'left_thumb2': [3,117], 'left_thumb3': [3,120],
        'right_index1': [3,123], 'right_index2': [3,126], 'right_index3': [3,129],
        'right_middle1': [3,132], 'right_middle2': [3,135], 'right_middle3': [3,138],
        'right_pinky1': [3,141], 'right_pinky2': [3,144], 'right_pinky3': [3,147],
        'right_ring1': [3,150], 'right_ring2': [3,153], 'right_ring3': [3,156],
        'right_thumb1': [3,159], 'right_thumb2': [3,162], 'right_thumb3': [3,165],
        # 'nose': [3,168], 'right_eye': [3,171], 'left_eye': [3,174],
        # 'right_ear': [3,177], 'left_ear': [3,180],
        # 'left_big_toe': [3,183], 'left_small_toe': [3,186], 'left_heel': [3,189],
        # 'right_big_toe': [3,192], 'right_small_toe': [3,195], 'right_heel': [3,198],
        # 'left_thumb': [3,201], 'left_index': [3,204], 'left_middle': [3,207],
        # 'left_ring': [3,210], 'left_pinky': [3,213],
        # 'right_thumb': [3,216], 'right_index': [3,219], 'right_middle': [3,222],
        # 'right_ring': [3,225], 'right_pinky': [3,228],
        # 'right_eye_brow1': [3,231], 'right_eye_brow2': [3,234], 'right_eye_brow3': [3,237],
        # 'right_eye_brow4': [3,240], 'right_eye_brow5': [3,243],
        # 'left_eye_brow5': [3,246], 'left_eye_brow4': [3,249], 'left_eye_brow3': [3,252],
        # 'left_eye_brow2': [3,255], 'left_eye_brow1': [3,258],
        # 'nose1': [3,261], 'nose2': [3,264], 'nose3': [3,267], 'nose4': [3,270],
        # 'right_nose_2': [3,273], 'right_nose_1': [3,276], 'nose_middle': [3,279],
        # 'left_nose_1': [3,282], 'left_nose_2': [3,285],
        # 'right_eye1': [3,288], 'right_eye2': [3,291], 'right_eye3': [3,294],
        # 'right_eye4': [3,297], 'right_eye5': [3,300], 'right_eye6': [3,303],
        # 'left_eye4': [3,306], 'left_eye3': [3,309], 'left_eye2': [3,312],
        # 'left_eye1': [3,315], 'left_eye6': [3,318], 'left_eye5': [3,321],
        # 'right_mouth_1': [3,324], 'right_mouth_2': [3,327], 'right_mouth_3': [3,330],
        # 'mouth_top': [3,333], 'left_mouth_3': [3,336], 'left_mouth_2': [3,339],
        # 'left_mouth_1': [3,342], 'left_mouth_5': [3,345], 'left_mouth_4': [3,348],
        # 'mouth_bottom': [3,351], 'right_mouth_4': [3,354], 'right_mouth_5': [3,357],
        # 'right_lip_1': [3,360], 'right_lip_2': [3,363], 'lip_top': [3,366],
        # 'left_lip_2': [3,369], 'left_lip_1': [3,372], 'left_lip_3': [3,375],
        # 'lip_bottom': [3,378], 'right_lip_3': [3,381],
        # 'right_contour_1': [3,384], 'right_contour_2': [3,387], 'right_contour_3': [3,390],
        # 'right_contour_4': [3,393], 'right_contour_5': [3,396], 'right_contour_6': [3,399],
        # 'right_contour_7': [3,402], 'right_contour_8': [3,405], 'contour_middle': [3,408],
        # 'left_contour_8': [3,411], 'left_contour_7': [3,414], 'left_contour_6': [3,417],
        # 'left_contour_5': [3,420], 'left_contour_4': [3,423], 'left_contour_3': [3,426],
        # 'left_contour_2': [3,429], 'left_contour_1': [3,432],
    },

    "beat_smplx_no_eyes": {
        "pelvis": 3, "left_hip": 3, "right_hip": 3, "spine1": 3,
        "left_knee": 3, "right_knee": 3, "spine2": 3,
        "left_ankle": 3, "right_ankle": 3, "spine3": 3,
        "left_foot": 3, "right_foot": 3, "neck": 3,
        "left_collar": 3, "right_collar": 3, "head": 3,
        "left_shoulder": 3, "right_shoulder": 3, "left_elbow": 3, "right_elbow": 3,
        "left_wrist": 3, "right_wrist": 3, "jaw": 3,
        # "left_eye_smplhf":3, "right_eye_smplhf":3,
        "left_index1": 3, "left_index2": 3, "left_index3": 3,
        "left_middle1": 3, "left_middle2": 3, "left_middle3": 3,
        "left_pinky1": 3, "left_pinky2": 3, "left_pinky3": 3,
        "left_ring1": 3, "left_ring2": 3, "left_ring3": 3,
        "left_thumb1": 3, "left_thumb2": 3, "left_thumb3": 3,
        "right_index1": 3, "right_index2": 3, "right_index3": 3,
        "right_middle1": 3, "right_middle2": 3, "right_middle3": 3,
        "right_pinky1": 3, "right_pinky2": 3, "right_pinky3": 3,
        "right_ring1": 3, "right_ring2": 3, "right_ring3": 3,
        "right_thumb1": 3, "right_thumb2": 3, "right_thumb3": 3,
    },

    "beat_smplx_full": {
        "pelvis": 3, "left_hip": 3, "right_hip": 3, "spine1": 3,
        "left_knee": 3, "right_knee": 3, "spine2": 3,
        "left_ankle": 3, "right_ankle": 3, "spine3": 3,
        "left_foot": 3, "right_foot": 3, "neck": 3,
        "left_collar": 3, "right_collar": 3, "head": 3,
        "left_shoulder": 3, "right_shoulder": 3, "left_elbow": 3, "right_elbow": 3,
        "left_wrist": 3, "right_wrist": 3, "jaw": 3,
        "left_eye_smplhf": 3, "right_eye_smplhf": 3,
        "left_index1": 3, "left_index2": 3, "left_index3": 3,
        "left_middle1": 3, "left_middle2": 3, "left_middle3": 3,
        "left_pinky1": 3, "left_pinky2": 3, "left_pinky3": 3,
        "left_ring1": 3, "left_ring2": 3, "left_ring3": 3,
        "left_thumb1": 3, "left_thumb2": 3, "left_thumb3": 3,
        "right_index1": 3, "right_index2": 3, "right_index3": 3,
        "right_middle1": 3, "right_middle2": 3, "right_middle3": 3,
        "right_pinky1": 3, "right_pinky2": 3, "right_pinky3": 3,
        "right_ring1": 3, "right_ring2": 3, "right_ring3": 3,
        "right_thumb1": 3, "right_thumb2": 3, "right_thumb3": 3,
    },

    "beat_smplx_upall": {
        # "pelvis":3, "left_hip":3, "right_hip":3,
        "spine1": 3,
        # "left_knee":3, "right_knee":3,
        "spine2": 3,
        # "left_ankle":3, "right_ankle":3,
        "spine3": 3,
        # "left_foot":3, "right_foot":3,
        "neck": 3, "left_collar": 3, "right_collar": 3, "head": 3,
        "left_shoulder": 3, "right_shoulder": 3, "left_elbow": 3, "right_elbow": 3,
        "left_wrist": 3, "right_wrist": 3,
        # "jaw":3, "left_eye_smplhf":3, "right_eye_smplhf":3,
        "left_index1": 3, "left_index2": 3, "left_index3": 3,
        "left_middle1": 3, "left_middle2": 3, "left_middle3": 3,
        "left_pinky1": 3, "left_pinky2": 3, "left_pinky3": 3,
        "left_ring1": 3, "left_ring2": 3, "left_ring3": 3,
        "left_thumb1": 3, "left_thumb2": 3, "left_thumb3": 3,
        "right_index1": 3, "right_index2": 3, "right_index3": 3,
        "right_middle1": 3, "right_middle2": 3, "right_middle3": 3,
        "right_pinky1": 3, "right_pinky2": 3, "right_pinky3": 3,
        "right_ring1": 3, "right_ring2": 3, "right_ring3": 3,
        "right_thumb1": 3, "right_thumb2": 3, "right_thumb3": 3,
    },

    "beat_smplx_upper": {
        # "pelvis":3, "left_hip":3, "right_hip":3,
        "spine1": 3,
        # "left_knee":3, "right_knee":3,
        "spine2": 3,
        # "left_ankle":3, "right_ankle":3,
        "spine3": 3,
        # "left_foot":3, "right_foot":3,
        "neck": 3, "left_collar": 3, "right_collar": 3, "head": 3,
        "left_shoulder": 3, "right_shoulder": 3, "left_elbow": 3, "right_elbow": 3,
        "left_wrist": 3, "right_wrist": 3,
        # "jaw":3, "left_eye_smplhf":3, "right_eye_smplhf":3,
        # "left_index1":3, "left_index2":3, "left_index3":3,
        # "left_middle1":3, "left_middle2":3, "left_middle3":3,
        # "left_pinky1":3, "left_pinky2":3, "left_pinky3":3,
        # "left_ring1":3, "left_ring2":3, "left_ring3":3,
        # "left_thumb1":3, "left_thumb2":3, "left_thumb3":3,
        # "right_index1":3, "right_index2":3, "right_index3":3,
        # "right_middle1":3, "right_middle2":3, "right_middle3":3,
        # "right_pinky1":3, "right_pinky2":3, "right_pinky3":3,
        # "right_ring1":3, "right_ring2":3, "right_ring3":3,
        # "right_thumb1":3, "right_thumb2":3, "right_thumb3":3,
    },

    "beat_smplx_hands": {
        # "pelvis":3, "left_hip":3, "right_hip":3, "spine1":3,
        # "left_knee":3, "right_knee":3, "spine2":3,
        # "left_ankle":3, "right_ankle":3, "spine3":3,
        # "left_foot":3, "right_foot":3, "neck":3,
        # "left_collar":3, "right_collar":3, "head":3,
        # "left_shoulder":3, "right_shoulder":3, "left_elbow":3, "right_elbow":3,
        # "left_wrist":3, "right_wrist":3, "jaw":3,
        # "left_eye_smplhf":3, "right_eye_smplhf":3,
        "left_index1": 3, "left_index2": 3, "left_index3": 3,
        "left_middle1": 3, "left_middle2": 3, "left_middle3": 3,
        "left_pinky1": 3, "left_pinky2": 3, "left_pinky3": 3,
        "left_ring1": 3, "left_ring2": 3, "left_ring3": 3,
        "left_thumb1": 3, "left_thumb2": 3, "left_thumb3": 3,
        "right_index1": 3, "right_index2": 3, "right_index3": 3,
        "right_middle1": 3, "right_middle2": 3, "right_middle3": 3,
        "right_pinky1": 3, "right_pinky2": 3, "right_pinky3": 3,
        "right_ring1": 3, "right_ring2": 3, "right_ring3": 3,
        "right_thumb1": 3, "right_thumb2": 3, "right_thumb3": 3,
    },

    "beat_smplx_lower": {
        "pelvis": 3, "left_hip": 3, "right_hip": 3,
        # "spine1":3,
        "left_knee": 3, "right_knee": 3,
        # "spine2":3,
        "left_ankle": 3, "right_ankle": 3,
        # "spine3":3,
        "left_foot": 3, "right_foot": 3,
        # "neck":3, "left_collar":3, "right_collar":3, "head":3,
        # "left_shoulder":3, "right_shoulder":3, "left_elbow":3, "right_elbow":3,
        # "left_wrist":3, "right_wrist":3, "jaw":3,
        # "left_eye_smplhf":3, "right_eye_smplhf":3,
        # "left_index1":3, "left_index2":3, "left_index3":3,
        # "left_middle1":3, "left_middle2":3, "left_middle3":3,
        # "left_pinky1":3, "left_pinky2":3, "left_pinky3":3,
        # "left_ring1":3, "left_ring2":3, "left_ring3":3,
        # "left_thumb1":3, "left_thumb2":3, "left_thumb3":3,
        # "right_index1":3, "right_index2":3, "right_index3":3,
        # "right_middle1":3, "right_middle2":3, "right_middle3":3,
        # "right_pinky1":3, "right_pinky2":3, "right_pinky3":3,
        # "right_ring1":3, "right_ring2":3, "right_ring3":3,
        # "right_thumb1":3, "right_thumb2":3, "right_thumb3":3,
    },

    "beat_smplx_face": {
        # "pelvis":3, "left_hip":3, "right_hip":3, "spine1":3,
        # "left_knee":3, "right_knee":3, "spine2":3,
        # "left_ankle":3, "right_ankle":3, "spine3":3,
        # "left_foot":3, "right_foot":3, "neck":3,
        # "left_collar":3, "right_collar":3, "head":3,
        # "left_shoulder":3, "right_shoulder":3, "left_elbow":3, "right_elbow":3,
        # "left_wrist":3, "right_wrist":3,
        "jaw": 3,
        # "left_eye_smplhf":3, "right_eye_smplhf":3,
        # "left_index1":3, "left_index2":3, "left_index3":3,
        # "left_middle1":3, "left_middle2":3, "left_middle3":3,
        # "left_pinky1":3, "left_pinky2":3, "left_pinky3":3,
        # "left_ring1":3, "left_ring2":3, "left_ring3":3,
        # "left_thumb1":3, "left_thumb2":3, "left_thumb3":3,
        # "right_index1":3, "right_index2":3, "right_index3":3,
        # "right_middle1":3, "right_middle2":3, "right_middle3":3,
        # "right_pinky1":3, "right_pinky2":3, "right_pinky3":3,
        # "right_ring1":3, "right_ring2":3, "right_ring3":3,
        # "right_thumb1":3, "right_thumb2":3, "right_thumb3":3,
    },

    "beat_joints": {
        'Hips': [6,6], 'Spine': [3,9], 'Spine1': [3,12], 'Spine2': [3,15], 'Spine3': [3,18],
        'Neck': [3,21], 'Neck1': [3,24], 'Head': [3,27], 'HeadEnd': [3,30],
        'RShoulder': [3,33], 'RArm': [3,36], 'RArm1': [3,39], 'RHand': [3,42],
        'RHandM1': [3,45], 'RHandM2': [3,48], 'RHandM3': [3,51], 'RHandM4': [3,54],
        'RHandR': [3,57], 'RHandR1': [3,60], 'RHandR2': [3,63], 'RHandR3': [3,66], 'RHandR4': [3,69],
        'RHandP': [3,72], 'RHandP1': [3,75], 'RHandP2': [3,78], 'RHandP3': [3,81], 'RHandP4': [3,84],
        'RHandI': [3,87], 'RHandI1': [3,90], 'RHandI2': [3,93], 'RHandI3': [3,96], 'RHandI4': [3,99],
        'RHandT1': [3,102], 'RHandT2': [3,105], 'RHandT3': [3,108], 'RHandT4': [3,111],
        'LShoulder': [3,114], 'LArm': [3,117], 'LArm1': [3,120], 'LHand': [3,123],
        'LHandM1': [3,126], 'LHandM2': [3,129], 'LHandM3': [3,132], 'LHandM4': [3,135],
        'LHandR': [3,138], 'LHandR1': [3,141], 'LHandR2': [3,144], 'LHandR3': [3,147], 'LHandR4': [3,150],
        'LHandP': [3,153], 'LHandP1': [3,156], 'LHandP2': [3,159], 'LHandP3': [3,162], 'LHandP4': [3,165],
        'LHandI': [3,168], 'LHandI1': [3,171], 'LHandI2': [3,174], 'LHandI3': [3,177], 'LHandI4': [3,180],
        'LHandT1': [3,183], 'LHandT2': [3,186], 'LHandT3': [3,189], 'LHandT4': [3,192],
        'RUpLeg': [3,195], 'RLeg': [3,198], 'RFoot': [3,201], 'RFootF': [3,204],
        'RToeBase': [3,207], 'RToeBaseEnd': [3,210],
        'LUpLeg': [3,213], 'LLeg': [3,216], 'LFoot': [3,219], 'LFootF': [3,222],
        'LToeBase': [3,225], 'LToeBaseEnd': [3,228],
    },

    "beat_full": {
        'Hips': 3, 'Spine': 3, 'Spine1': 3, 'Spine2': 3, 'Spine3': 3,
        'Neck': 3, 'Neck1': 3, 'Head': 3, 'HeadEnd': 3,
        'RShoulder': 3, 'RArm': 3, 'RArm1': 3, 'RHand': 3,
        'RHandM1': 3, 'RHandM2': 3, 'RHandM3': 3, 'RHandM4': 3,
        'RHandR': 3, 'RHandR1': 3, 'RHandR2': 3, 'RHandR3': 3, 'RHandR4': 3,
        'RHandP': 3, 'RHandP1': 3, 'RHandP2': 3, 'RHandP3': 3, 'RHandP4': 3,
        'RHandI': 3, 'RHandI1': 3, 'RHandI2': 3, 'RHandI3': 3, 'RHandI4': 3,
        'RHandT1': 3, 'RHandT2': 3, 'RHandT3': 3, 'RHandT4': 3,
        'LShoulder': 3, 'LArm': 3, 'LArm1': 3, 'LHand': 3,
        'LHandM1': 3, 'LHandM2': 3, 'LHandM3': 3, 'LHandM4': 3,
        'LHandR': 3, 'LHandR1': 3, 'LHandR2': 3, 'LHandR3': 3, 'LHandR4': 3,
        'LHandP': 3, 'LHandP1': 3, 'LHandP2': 3, 'LHandP3': 3, 'LHandP4': 3,
        'LHandI': 3, 'LHandI1': 3, 'LHandI2': 3, 'LHandI3': 3, 'LHandI4': 3,
        'LHandT1': 3, 'LHandT2': 3, 'LHandT3': 3, 'LHandT4': 3,
        'RUpLeg': 3, 'RLeg': 3, 'RFoot': 3, 'RFootF': 3, 'RToeBase': 3, 'RToeBaseEnd': 3,
        'LUpLeg': 3, 'LLeg': 3, 'LFoot': 3, 'LFootF': 3, 'LToeBase': 3, 'LToeBaseEnd': 3,
    },

    "japanese_joints": {
        'Hips': [6,6], 'Spine': [6,12], 'Spine1': [6,18], 'Spine2': [6,24], 'Spine3': [6,30],
        'Neck': [6,36], 'Neck1': [6,42], 'Head': [6,48],
        'RShoulder': [6,54], 'RArm': [6,60], 'RArm1': [6,66], 'RHand': [6,72],
        'RHandM1': [6,78], 'RHandM2': [6,84], 'RHandM3': [6,90],
        'RHandR': [6,96], 'RHandR1': [6,102], 'RHandR2': [6,108], 'RHandR3': [6,114],
        'RHandP': [6,120], 'RHandP1': [6,126], 'RHandP2': [6,132], 'RHandP3': [6,138],
        'RHandI': [6,144], 'RHandI1': [6,150], 'RHandI2': [6,156], 'RHandI3': [6,162],
        'RHandT1': [6,168], 'RHandT2': [6,174], 'RHandT3': [6,180],
        'LShoulder': [6,186], 'LArm': [6,192], 'LArm1': [6,198], 'LHand': [6,204],
        'LHandM1': [6,210], 'LHandM2': [6,216], 'LHandM3': [6,222],
        'LHandR': [6,228], 'LHandR1': [6,234], 'LHandR2': [6,240], 'LHandR3': [6,246],
        'LHandP': [6,252], 'LHandP1': [6,258], 'LHandP2': [6,264], 'LHandP3': [6,270],
        'LHandI': [6,276], 'LHandI1': [6,282], 'LHandI2': [6,288], 'LHandI3': [6,294],
        'LHandT1': [6,300], 'LHandT2': [6,306], 'LHandT3': [6,312],
        'RUpLeg': [6,318], 'RLeg': [6,324], 'RFoot': [6,330], 'RFootF': [6,336], 'RToeBase': [6,342],
        'LUpLeg': [6,348], 'LLeg': [6,354], 'LFoot': [6,360], 'LFootF': [6,366], 'LToeBase': [6,372],
    },

    "yostar": {
        'Hips': [6,6], 'Spine': [3,9], 'Spine1': [3,12], 'Bone040': [3,15], 'Bone041': [3,18],
        'Bone034': [3,21], 'Bone035': [3,24], 'Bone036': [3,27],
        'Bone037': [3,30], 'Bone038': [3,33], 'Bone039': [3,36],
        'RibbonL1': [3,39], 'RibbonL1_end': [3,42],
        'Chest': [3,45], 'L_eri': [3,48], 'R_eri': [3,51],
        'Neck': [3,54], 'Head': [3,57], 'Head_end': [3,60],
        'RBackHair_1': [3,63], 'RBackHair_2': [3,66], 'RBackHair_3': [3,69],
        'RBackHair_4': [3,72], 'RBackHair_end': [3,75],
        'RFrontHair': [3,78], 'CFrontHair_1': [3,81], 'CFrontHair_2': [3,84],
        'CFrontHair_3': [3,87], 'CFrontHair_emd': [3,90],
        'LFrontHair_1': [3,93], 'LFrontHair_2': [3,96], 'LFrontHair_3': [3,99],
        'LBackHair_1': [3,102], 'LBackHair_2': [3,105], 'LBackHair_3': [3,108],
        'LBackHair_4': [3,111], 'LBackHair_end': [3,114],
        'LSideHair_1': [3,117], 'LSideHair_2': [3,120], 'LSideHair_3': [3,123],
        'LSideHair_4': [3,126], 'LSideHair_5': [3,129], 'LSideHair_6': [3,132],
        'LSideHair_7': [3,135], 'LSideHair_end': [3,138],
        'CBackHair_1': [3,141], 'CBackHair_2': [3,144], 'CBackHair_3': [3,147],
        'CBackHair_4': [3,150], 'CBackHair_end': [3,153],
        'RSideHair_1': [3,156], 'RSideHair_2': [3,159], 'RSideHair_3': [3,162], 'RSideHair_4': [3,165],
        'RibbonR_1': [3,168], 'RibbonR_2': [3,171], 'RibbonR_3': [3,174],
        'RibbonL_1': [3,177], 'RibbonL_2': [3,180], 'RibbonL_3': [3,183],
        'LeftEye': [3,186], 'LeftEye_end': [3,189], 'RightEye': [3,192], 'RightEye_end': [3,195],
        'LeftShoulder': [3,198], 'LeftArm': [3,201], 'LeftForearm': [3,204], 'LeftHand': [3,207],
        'LeftHandThumb1': [3,210], 'LeftHandThumb2': [3,213], 'LeftHandThumb3': [3,216], 'LeftHandThumb_end': [3,219],
        'LeftHandIndex1': [3,222], 'LeftHandIndex2': [3,225], 'LeftHandIndex3': [3,228], 'LeftHandIndex_end': [3,231],
        'LeftHandMiddle1': [3,234], 'LeftHandMiddle2': [3,237], 'LeftHandMiddle3': [3,240], 'LeftHandMiddle_end': [3,243],
        'LeftHandRing1': [3,246], 'LeftHandRing2': [3,249], 'LeftHandRing3': [3,252], 'LeftHandRing_end': [3,255],
        'LeftHandPinky1': [3,258], 'LeftHandPinky2': [3,261], 'LeftHandPinky3': [3,264], 'LeftHandPinky_end': [3,267],
        'RightShoulder': [3,270], 'RightArm': [3,273], 'RightForearm': [3,276], 'RightHand': [3,279],
        'RightHandThumb1': [3,282], 'RightHandThumb2': [3,285], 'RightHandThumb3': [3,288], 'RightHandThumb_end': [3,291],
        'RightHandIndex1': [3,294], 'RightHandIndex2': [3,297], 'RightHandIndex3': [3,300], 'RightHandIndex_end': [3,303],
        'RightHandMiddle1': [3,306], 'RightHandMiddle2': [3,309], 'RightHandMiddle3': [3,312], 'RightHandMiddle_end': [3,315],
        'RightHandRing1': [3,318], 'RightHandRing2': [3,321], 'RightHandRing3': [3,324], 'RightHandRing_end': [3,327],
        'RightHandPinky1': [3,330], 'RightHandPinky2': [3,333], 'RightHandPinky3': [3,336], 'RightHandPinky_end': [3,339],
        'RibbonR1': [3,342], 'RibbonR1_end': [3,345], 'RibbonR2': [3,348], 'RibbonR2_end': [3,351],
        'RibbonL2': [3,354], 'RibbonL2_end': [3,357],
        'LeftUpLeg': [3,360], 'LeftLeg': [3,363], 'LeftFoot': [3,366], 'LeftToe': [3,369], 'LeftToe_end': [3,372],
        'RightUpLeg': [3,375], 'RightLEg': [3,378], 'RightFoot': [3,381], 'RightToe': [3,384], 'RightToe_end': [3,387],
        'bone_skirtF00': [3, 390], 'bone_skirtF01': [3, 393], 'bone_skirtF02': [3, 396],
        'bone_skirtF03': [3, 399], 'Bone020': [3, 402], 'Bone026': [3, 405],
        'bone_skirtF_R_00': [3, 408], 'bone_skirtF_R_01': [3, 411], 'bone_skirtF_R_02': [3, 414],
        'bone_skirtF_R_03': [3, 417], 'Bone019': [3, 420], 'Bone028': [3, 423],
        'bone_skirtR00': [3, 426], 'bone_skirtR01': [3, 429], 'bone_skirtR02': [3, 432],
        'bone_skirtR03': [3, 435], 'Bone018': [3, 438], 'Bone029': [3, 441],
        'bone_skirtF_L_00': [3, 444], 'bone_skirtF_L_01': [3, 447], 'bone_skirtF_L_02': [3, 450],
        'bone_skirtF_L_03': [3, 453], 'Bone021': [3, 456], 'Bone027': [3, 459],
        'bone_skirtL00': [3, 462], 'bone_skirtL01': [3, 465], 'bone_skirtL02': [3, 468],
        'bone_skirtL03': [3, 471], 'Bone022': [3, 474], 'Bone033': [3, 477],
        'bone_skirtB_L_00': [3, 480], 'bone_skirtB_L_01': [3, 483], 'bone_skirtB_L_02': [3, 486],
        'bone_skirtB_L_03': [3, 489], 'Bone023': [3, 492], 'Bone032': [3, 495],
        'bone_skirtB00': [3, 498], 'bone_skirtB01': [3, 501], 'bone_skirtB02': [3, 504],
        'bone_skirtB03': [3, 507], 'Bone024': [3, 510], 'Bone031': [3, 513],
        'bone_skirtB_R_00': [3, 516], 'bone_skirtB_R_01': [3, 519], 'bone_skirtB_R_02': [3, 521],
        'bone_skirtB_R_03': [3, 524], 'Bone025': [3, 527], 'Bone030': [3, 530],
    },

    "yostar_fullbody_213": {
        'Hips': 3, 'Spine': 3, 'Spine1': 3, 'Chest': 3, 'L_eri': 3, 'R_eri': 3,
        'Neck': 3, 'Head': 3, 'Head_end': 3,
        'LeftEye': 3, 'LeftEye_end': 3, 'RightEye': 3, 'RightEye_end': 3,
        'LeftShoulder': 3, 'LeftArm': 3, 'LeftForearm': 3, 'LeftHand': 3,
        'LeftHandThumb1': 3, 'LeftHandThumb2': 3, 'LeftHandThumb3': 3, 'LeftHandThumb_end': 3,
        'LeftHandIndex1': 3, 'LeftHandIndex2': 3, 'LeftHandIndex3': 3, 'LeftHandIndex_end': 3,
        'LeftHandMiddle1': 3, 'LeftHandMiddle2': 3, 'LeftHandMiddle3': 3, 'LeftHandMiddle_end': 3,
        'LeftHandRing1': 3, 'LeftHandRing2': 3, 'LeftHandRing3': 3, 'LeftHandRing_end': 3,
        'LeftHandPinky1': 3, 'LeftHandPinky2': 3, 'LeftHandPinky3': 3, 'LeftHandPinky_end': 3,
        'RightShoulder': 3, 'RightArm': 3, 'RightForearm': 3, 'RightHand': 3,
        'RightHandThumb1': 3, 'RightHandThumb2': 3, 'RightHandThumb3': 3, 'RightHandThumb_end': 3,
        'RightHandIndex1': 3, 'RightHandIndex2': 3, 'RightHandIndex3': 3, 'RightHandIndex_end': 3,
        'RightHandMiddle1': 3, 'RightHandMiddle2': 3, 'RightHandMiddle3': 3, 'RightHandMiddle_end': 3,
        'RightHandRing1': 3, 'RightHandRing2': 3, 'RightHandRing3': 3, 'RightHandRing_end': 3,
        'RightHandPinky1': 3, 'RightHandPinky2': 3, 'RightHandPinky3': 3, 'RightHandPinky_end': 3,
        'LeftUpLeg': 3, 'LeftLeg': 3, 'LeftFoot': 3, 'LeftToe': 3, 'LeftToe_end': 3,
        'RightUpLeg': 3, 'RightLEg': 3, 'RightFoot': 3, 'RightToe': 3, 'RightToe_end': 3,
    },
    "yostar_mainbody_48": {
        #'Hips': 3,
        'Spine': 3, 'Spine1': 3, 'Chest': 3, 'L_eri': 3, 'R_eri': 3,
        'Neck': 3, 'Head': 3, 'Head_end': 3,
        'LeftShoulder': 3, 'LeftArm': 3, 'LeftForearm': 3, 'LeftHand': 3,
        'RightShoulder': 3, 'RightArm': 3, 'RightForearm': 3, 'RightHand': 3,
    },
    "yostar_mainbody_69": {
        'Hips': 3, 'Spine': 3, 'Spine1': 3, 'Chest': 3, 'L_eri': 3, 'R_eri': 3,
        'Neck': 3, 'Head': 3, 'Head_end': 3,
        'LeftShoulder': 3, 'LeftArm': 3, 'LeftForearm': 3, 'LeftHand': 3,
        'RightShoulder': 3, 'RightArm': 3, 'RightForearm': 3, 'RightHand': 3,
        'LeftUpLeg': 3, 'LeftLeg': 3, 'LeftFoot': 3,
        'RightUpLeg': 3, 'RightLEg': 3, 'RightFoot': 3,
    },

    "yostar_upbody_168": {
        #'Hips': 3,
        'Spine': 3, 'Spine1': 3, 'Chest': 3, 'L_eri': 3, 'R_eri': 3,
        'Neck': 3, 'Head': 3, 'Head_end': 3,
        'LeftShoulder': 3, 'LeftArm': 3, 'LeftForearm': 3, 'LeftHand': 3,
        'LeftHandThumb1': 3, 'LeftHandThumb2': 3, 'LeftHandThumb3': 3, 'LeftHandThumb_end': 3,
        'LeftHandIndex1': 3, 'LeftHandIndex2': 3, 'LeftHandIndex3': 3, 'LeftHandIndex_end': 3,
        'LeftHandMiddle1': 3, 'LeftHandMiddle2': 3, 'LeftHandMiddle3': 3, 'LeftHandMiddle_end': 3,
        'LeftHandRing1': 3, 'LeftHandRing2': 3, 'LeftHandRing3': 3, 'LeftHandRing_end': 3,
        'LeftHandPinky1': 3, 'LeftHandPinky2': 3, 'LeftHandPinky3': 3, 'LeftHandPinky_end': 3,
        'RightShoulder': 3, 'RightArm': 3, 'RightForearm': 3, 'RightHand': 3,
        'RightHandThumb1': 3, 'RightHandThumb2': 3, 'RightHandThumb3': 3, 'RightHandThumb_end': 3,
        'RightHandIndex1': 3, 'RightHandIndex2': 3, 'RightHandIndex3': 3, 'RightHandIndex_end': 3,
        'RightHandMiddle1': 3, 'RightHandMiddle2': 3, 'RightHandMiddle3': 3, 'RightHandMiddle_end': 3,
        'RightHandRing1': 3, 'RightHandRing2': 3, 'RightHandRing3': 3, 'RightHandRing_end': 3,
        'RightHandPinky1': 3, 'RightHandPinky2': 3, 'RightHandPinky3': 3, 'RightHandPinky_end': 3,
    },
    "spine_neck_141": {
        'Spine': 3, 'Neck': 3, 'Neck1': 3,
        'RShoulder': 3, 'RArm': 3, 'RArm1': 3, 'RHand': 3,
        'RHandM1': 3, 'RHandM2': 3, 'RHandM3': 3,
        'RHandR': 3, 'RHandR1': 3, 'RHandR2': 3, 'RHandR3': 3,
        'RHandP': 3, 'RHandP1': 3, 'RHandP2': 3, 'RHandP3': 3,
        'RHandI': 3, 'RHandI1': 3, 'RHandI2': 3, 'RHandI3': 3,
        'RHandT1': 3, 'RHandT2': 3, 'RHandT3': 3,
        'LShoulder': 3, 'LArm': 3, 'LArm1': 3, 'LHand': 3,
        'LHandM1': 3, 'LHandM2': 3, 'LHandM3': 3,
        'LHandR': 3, 'LHandR1': 3, 'LHandR2': 3, 'LHandR3': 3,
        'LHandP': 3, 'LHandP1': 3, 'LHandP2': 3, 'LHandP3': 3,
        'LHandI': 3, 'LHandI1': 3, 'LHandI2': 3, 'LHandI3': 3,
        'LHandT1': 3, 'LHandT2': 3, 'LHandT3': 3,
    },
}

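A minimal sketch of how the two table styles above combine during retargeting, using only names defined in the tables (the frame length 228 follows from 'LToeBaseEnd': [3,228], and "spine_neck_141" holds 47 joints * 3 = 141 channels). Full layouts give a joint's [n_channels, end_column] in the flattened BVH frame, while subset tables give only n_channels, so end_column - n_channels is the start of the slice to overwrite:

import numpy as np

ori_list = joints_list["beat_joints"]        # joint -> [n_channels, end_column]
target_list = joints_list["spine_neck_141"]  # joint -> n_channels (subset to overwrite)

frame = np.zeros(228)                    # one flattened BVH motion frame of the beat skeleton
new_vals = np.arange(141, dtype=float)   # stand-in for one frame of generated rotations

for i, (name, n) in enumerate(target_list.items()):
    end = ori_list[name][1]              # this joint's channel block ends at column `end`
    frame[end - n:end] = new_vals[i * 3:i * 3 + 3]

This is exactly the indexing that result2target_vis further below performs per frame.
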
class FIDCalculator(object):
    '''
    todo
    '''
    def __init__(self):
        self.gt_rot = None  # pandas dataframe for n frames * joints * 6
        self.gt_pos = None  # n frames * (joints + 13) * 3
        self.op_rot = None  # pandas dataframe for n frames * joints * 6
        self.op_pos = None  # n frames * (joints + 13) * 3

    def load(self, path, load_type, save_pos=False):
        '''
        select gt or op for load_type
        '''
        parser = BVHParser()
        parsed_data = parser.parse(path)
        if load_type == 'gt':
            self.gt_rot = parsed_data.values
        elif load_type == 'op':
            self.op_rot = parsed_data.values
        else: print('error, select gt or op for load_type')

        if save_pos:
            mp = MocapParameterizer('position')
            positions = mp.fit_transform([parsed_data])
            if load_type == 'gt':
                self.gt_pos = positions[0].values
            elif load_type == 'op':
                self.op_pos = positions[0].values
            else: print('error, select gt or op for load_type')

    def _joint_selector(self, selected_joints, ori_data):
        selected_data = pd.DataFrame(columns=[])

        for joint_name in selected_joints:
            selected_data[joint_name] = ori_data[joint_name]
        return selected_data.to_numpy()

    def cal_vol(self, dtype):
        if dtype == 'pos':
            gt = self.gt_pos
            op = self.op_pos
        else:
            gt = self.gt_rot
            op = self.op_rot

        gt_v = gt.to_numpy()[1:, :] - gt.to_numpy()[0:-1, :]
        op_v = op.to_numpy()[1:, :] - op.to_numpy()[0:-1, :]
        if dtype == 'pos':
            self.gt_vol_pos = pd.DataFrame(gt_v, columns=gt.columns.tolist())
            self.op_vol_pos = pd.DataFrame(op_v, columns=gt.columns.tolist())
        else:
            self.gt_vol_rot = pd.DataFrame(gt_v, columns=gt.columns.tolist())
            self.op_vol_rot = pd.DataFrame(op_v, columns=gt.columns.tolist())

    @staticmethod
    def frechet_distance(samples_A, samples_B):
        A_mu = np.mean(samples_A, axis=0)
        A_sigma = np.cov(samples_A, rowvar=False)
        B_mu = np.mean(samples_B, axis=0)
        B_sigma = np.cov(samples_B, rowvar=False)
        try:
            frechet_dist = FIDCalculator.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
        except ValueError:
            frechet_dist = 1e+10
        return frechet_dist

    @staticmethod
    def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
        """from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py

        Numpy implementation of the Frechet Distance.
        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
        and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
        Stable version by Dougal J. Sutherland.
        Params:
        -- mu1   : Numpy array containing the activations of a layer of the
                   inception net (like returned by the function 'get_predictions')
                   for generated samples.
        -- mu2   : The sample mean over activations, precalculated on a
                   representative data set.
        -- sigma1: The covariance matrix over activations for generated samples.
        -- sigma2: The covariance matrix over activations, precalculated on a
                   representative data set.
        Returns:
        --       : The Frechet Distance.
        """

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        #print(mu1[0], mu2[0])
        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)
        #print(sigma1[0], sigma2[0])
        assert mu1.shape == mu2.shape, \
            'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, \
            'Training and test covariances have different dimensions'

        diff = mu1 - mu2

        # Product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        #print(diff, covmean[0])
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return (diff.dot(diff) + np.trace(sigma1) +
                np.trace(sigma2) - 2 * tr_covmean)

    def calculate_fid(self, cal_type, joint_type, high_level_opt):

        if cal_type == 'pos':
            if self.gt_pos.shape != self.op_pos.shape:
                min_val = min(self.gt_pos.shape[0], self.op_pos.shape[0])
                gt = self.gt_pos[:min_val]
                op = self.op_pos[:min_val]
            else:
                gt = self.gt_pos
                op = self.op_pos
            full_body = gt.columns.tolist()
        elif cal_type == 'rot':
            if self.gt_rot.shape != self.op_rot.shape:
                min_val = min(self.gt_rot.shape[0], self.op_rot.shape[0])
                gt = self.gt_rot[:min_val]
                op = self.op_rot[:min_val]
            else:
                gt = self.gt_rot
                op = self.op_rot
            full_body_with_offset = gt.columns.tolist()
            full_body = [o for o in full_body_with_offset if ('position' not in o)]
        elif cal_type == 'pos_vol':
            assert self.gt_vol_pos.shape == self.op_vol_pos.shape
            gt = self.gt_vol_pos
            op = self.op_vol_pos
            full_body_with_offset = gt.columns.tolist()
            full_body = gt.columns.tolist()
        elif cal_type == 'rot_vol':
            assert self.gt_vol_rot.shape == self.op_vol_rot.shape
            gt = self.gt_vol_rot
            op = self.op_vol_rot
            full_body_with_offset = gt.columns.tolist()
            full_body = [o for o in full_body_with_offset if ('position' not in o)]
        #print(f'full_body contains {len(full_body)//3} joints')

        if joint_type == 'full_upper_body':
            selected_body = [o for o in full_body if ('Leg' not in o) and ('Foot' not in o) and ('Toe' not in o)]
        elif joint_type == 'upper_body':
            selected_body = [o for o in full_body if ('Hand' not in o) and ('Leg' not in o) and ('Foot' not in o) and ('Toe' not in o)]
        elif joint_type == 'fingers':
            selected_body = [o for o in full_body if ('Hand' in o)]
        elif joint_type == 'indivdual':
            pass
        else: print('error, plz select correct joint type')
        #print(f'calculate fid for {len(selected_body)//3} joints')

        gt = self._joint_selector(selected_body, gt)
        op = self._joint_selector(selected_body, op)

        if high_level_opt == 'fid':
            fid = FIDCalculator.frechet_distance(gt, op)
            return fid
        elif high_level_opt == 'var':
            var_gt = gt.var()
            var_op = op.var()
            return var_gt, var_op
        elif high_level_opt == 'mean':
            mean_gt = gt.mean()
            mean_op = op.mean()
            return mean_gt, mean_op
        else: return 0

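A usage sketch for FIDCalculator, assuming a pair of ground-truth and generated BVH files (the file paths below are hypothetical). Note that calculate_fid with cal_type='pos' requires save_pos=True at load time, and cal_type='pos_vol'/'rot_vol' additionally requires a prior cal_vol() call:

calc = FIDCalculator()
calc.load('gt/sample.bvh', load_type='gt', save_pos=True)    # hypothetical path
calc.load('res/sample.bvh', load_type='op', save_pos=True)   # hypothetical path
fid = calc.calculate_fid(cal_type='pos', joint_type='upper_body', high_level_opt='fid')
calc.cal_vol('pos')  # prepares gt_vol_pos / op_vol_pos for cal_type='pos_vol'
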
| 1688 |
+
def result2target_vis(pose_version, res_bvhlist, save_path, demo_name, verbose=True):
|
| 1689 |
+
if "trinity" in pose_version:
|
| 1690 |
+
ori_list = joints_list[pose_version[6:-4]]
|
| 1691 |
+
target_list = joints_list[pose_version[6:]]
|
| 1692 |
+
file_content_length = 336
|
| 1693 |
+
elif "beat" in pose_version or "spine_neck_141" in pose_version:
|
| 1694 |
+
ori_list = joints_list["beat_joints"]
|
| 1695 |
+
target_list = joints_list["spine_neck_141"]
|
| 1696 |
+
file_content_length = 431
|
| 1697 |
+
elif "yostar" in pose_version:
|
| 1698 |
+
ori_list = joints_list["yostar"]
|
| 1699 |
+
target_list = joints_list[pose_version]
|
| 1700 |
+
file_content_length = 1056
|
| 1701 |
+
else:
|
| 1702 |
+
ori_list = joints_list["japanese_joints"]
|
| 1703 |
+
target_list = joints_list[pose_version]
|
| 1704 |
+
file_content_length = 366
|
| 1705 |
+
|
| 1706 |
+
bvh_files_dirs = sorted(glob.glob(f'{res_bvhlist}*.bvh'), key=str)
|
| 1707 |
+
#test_seq_list = os.list_dir(demo_name).sort()
|
| 1708 |
+
|
| 1709 |
+
counter = 0
|
| 1710 |
+
if not os.path.exists(save_path):
|
| 1711 |
+
os.makedirs(save_path)
|
| 1712 |
+
for i, bvh_file_dir in enumerate(bvh_files_dirs):
|
| 1713 |
+
short_name = bvh_file_dir.split("/")[-1][11:]
|
| 1714 |
+
#print(short_name)
|
| 1715 |
+
wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'w+')
|
| 1716 |
+
with open(f"{demo_name}{short_name}",'r') as pose_data_pre:
|
| 1717 |
+
pose_data_pre_file = pose_data_pre.readlines()
|
| 1718 |
+
for j, line in enumerate(pose_data_pre_file[0:file_content_length]):
|
| 1719 |
+
wirte_file.write(line)
|
| 1720 |
+
offset_data = pose_data_pre_file[file_content_length]
|
| 1721 |
+
offset_data = np.fromstring(offset_data, dtype=float, sep=' ')
|
| 1722 |
+
wirte_file.close()
|
| 1723 |
+
|
| 1724 |
+
wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'r')
|
| 1725 |
+
ori_lines = wirte_file.readlines()
|
| 1726 |
+
with open(bvh_file_dir, 'r') as pose_data:
|
| 1727 |
+
pose_data_file = pose_data.readlines()
|
| 1728 |
+
ori_lines[file_content_length-2] = 'Frames: ' + str(len(pose_data_file)-1) + '\n'
|
| 1729 |
+
wirte_file.close()
|
| 1730 |
+
|
| 1731 |
+
wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'w+')
|
| 1732 |
+
wirte_file.writelines(i for i in ori_lines[:file_content_length])
|
| 1733 |
+
wirte_file.close()
|
| 1734 |
+
|
| 1735 |
+
with open(os.path.join(save_path, f'res_{short_name}'),'a+') as wirte_file:
|
| 1736 |
+
with open(bvh_file_dir, 'r') as pose_data:
|
| 1737 |
+
data_each_file = []
|
| 1738 |
+
pose_data_file = pose_data.readlines()
|
| 1739 |
+
for j, line in enumerate(pose_data_file):
|
| 1740 |
+
if not j:
|
| 1741 |
+
pass
|
| 1742 |
+
else:
|
| 1743 |
+
data = np.fromstring(line, dtype=float, sep=' ')
|
| 1744 |
+
data_rotation = offset_data.copy()
|
| 1745 |
+
for iii, (k, v) in enumerate(target_list.items()): # here is 147 rotations by 3
|
| 1746 |
+
#print(data_rotation[ori_list[k][1]-v:ori_list[k][1]], data[iii*3:iii*3+3])
|
| 1747 |
+
data_rotation[ori_list[k][1]-v:ori_list[k][1]] = data[iii*3:iii*3+3]
|
| 1748 |
+
data_each_file.append(data_rotation)
|
| 1749 |
+
|
| 1750 |
+
for line_data in data_each_file:
|
| 1751 |
+
line_data = np.array2string(line_data, max_line_width=np.inf, precision=6, suppress_small=False, separator=' ')
|
| 1752 |
+
wirte_file.write(line_data[1:-2]+'\n')
|
| 1753 |
+
|
| 1754 |
+
counter += 1
|
| 1755 |
+
if verbose:
|
| 1756 |
+
logger.info('data_shape:', data_rotation.shape, 'process:', counter, '/', len(bvh_files_dirs))
|
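
A minimal usage sketch for the metrics above (illustrative, not part of the uploaded file: it assumes the calculator's gt_pos / op_pos DataFrames were already populated by the loading helpers defined earlier in other_tools.py, with one column per joint channel such as 'Spine_Xrotation'):

    calc = FIDCalculator()
    # ... fill calc.gt_pos / calc.op_pos via the class's own loaders ...
    fid = calc.calculate_fid(cal_type='pos', joint_type='upper_body', high_level_opt='fid')
    var_gt, var_op = calc.calculate_fid('pos', 'upper_body', 'var')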
dataloaders/mix_sep.py
ADDED
@@ -0,0 +1,301 @@
import os
import pickle
import math
import shutil
import numpy as np
import lmdb as lmdb
import textgrid as tg
import pandas as pd
import torch
import glob
import json
from termcolor import colored
from loguru import logger
from collections import defaultdict
from torch.utils.data import Dataset
import torch.distributed as dist
#import pyarrow
import librosa
import smplx

from .build_vocab import Vocab
from .utils.audio_features import Wav2Vec2Model
from .data_tools import joints_list
from .utils import rotation_conversions as rc
from .utils import other_tools


class CustomDataset(Dataset):
    def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True):
        self.args = args
        self.loader_type = loader_type

        self.rank = 0
        self.ori_stride = self.args.stride
        self.ori_length = self.args.pose_length

        self.ori_joint_list = joints_list[self.args.ori_joints]
        self.tar_joint_list = joints_list[self.args.tar_joints]
        if 'smplx' in self.args.pose_rep:
            self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3)
            self.joints = len(list(self.tar_joint_list.keys()))
            for joint_name in self.tar_joint_list:
                self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
        else:
            self.joints = len(list(self.ori_joint_list.keys()))+1
            self.joint_mask = np.zeros(self.joints*3)
            for joint_name in self.tar_joint_list:
                if joint_name == "Hips":
                    self.joint_mask[3:6] = 1
                else:
                    self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1
        # select trainable joints

        split_rule = pd.read_csv(args.data_path+"train_test_split.csv")
        self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
        if args.additional_data and loader_type == 'train':
            split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
            #self.selected_file = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
            self.selected_file = pd.concat([self.selected_file, split_b])
        if self.selected_file.empty:
            logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead")
            self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))]
            self.selected_file = self.selected_file.iloc[0:8]
        self.data_dir = args.data_path
        self.beatx_during_time = 0

        if loader_type == "test":
            self.args.multi_length_training = [1.0]
        self.max_length = int(args.pose_length * self.args.multi_length_training[-1])
        self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr)
        if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr:
            self.max_audio_pre_len = self.args.test_length*self.args.audio_sr
        preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache"

        if build_cache and self.rank == 0:
            self.build_cache(preloaded_dir)
        self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False)
        with self.lmdb_env.begin() as txn:
            self.n_samples = txn.stat()["entries"]

        self.norm = True
        self.mean = np.load('./mean_std/beatx_2_330_mean.npy')
        self.std = np.load('./mean_std/beatx_2_330_std.npy')

        self.trans_mean = np.load('./mean_std/beatx_2_trans_mean.npy')
        self.trans_std = np.load('./mean_std/beatx_2_trans_std.npy')

    def build_cache(self, preloaded_dir):
        logger.info(f"Audio bit rate: {self.args.audio_fps}")
        logger.info("Reading data '{}'...".format(self.data_dir))
        logger.info("Creating the dataset cache...")
        if self.args.new_cache:
            if os.path.exists(preloaded_dir):
                shutil.rmtree(preloaded_dir)
        if os.path.exists(preloaded_dir):
            logger.info("Found the cache {}".format(preloaded_dir))
        elif self.loader_type == "test":
            self.cache_generation(
                preloaded_dir, True,
                0, 0,
                is_test=True)
        else:
            self.cache_generation(
                preloaded_dir, self.args.disable_filtering,
                self.args.clean_first_seconds, self.args.clean_final_seconds,
                is_test=False)
        logger.info(f"BEATX during time is {self.beatx_during_time}s !")

    def __len__(self):
        return self.n_samples

    def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False):
        self.n_out_samples = 0
        # create db for samples
        if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir)
        dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size=int(1024 ** 3 * 50))  # 50 GB
        n_filtered_out = defaultdict(int)

        for index, file_name in self.selected_file.iterrows():
            f_name = file_name["id"]
            ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh"
            pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext
            pose_each_file = []
            trans_each_file = []
            trans_v_each_file = []
            shape_each_file = []
            audio_each_file = []
            facial_each_file = []
            word_each_file = []
            emo_each_file = []
            sem_each_file = []
            vid_each_file = []
            id_pose = f_name  # e.g. 1_wayne_0_1_1

            logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue"))
            if "smplx" in self.args.pose_rep:
                pose_data = np.load(pose_file, allow_pickle=True)
                assert 30 % self.args.pose_fps == 0, 'pose_fps should be a divisor of 30'
                stride = int(30/self.args.pose_fps)
                pose_each_file = pose_data["poses"][::stride] * self.joint_mask
                pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)]

                self.beatx_during_time += pose_each_file.shape[0]/30
                # zero out the x/z start position, then convert x/z to per-frame velocities
                trans_each_file = pose_data["trans"][::stride]
                trans_each_file[:,0] = trans_each_file[:,0] - trans_each_file[0,0]
                trans_each_file[:,2] = trans_each_file[:,2] - trans_each_file[0,2]
                trans_v_each_file = np.zeros_like(trans_each_file)
                trans_v_each_file[1:,0] = trans_each_file[1:,0] - trans_each_file[:-1,0]
                trans_v_each_file[0,0] = trans_v_each_file[1,0]
                trans_v_each_file[1:,2] = trans_each_file[1:,2] - trans_each_file[:-1,2]
                trans_v_each_file[0,2] = trans_v_each_file[1,2]
                trans_v_each_file[:,1] = trans_each_file[:,1]

                shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0)
                if self.args.facial_rep is not None:
                    logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #")
                    facial_each_file = pose_data["expressions"][::stride]
                    if self.args.facial_norm:
                        facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial

            if self.args.id_rep is not None:
                vid_each_file = np.repeat(np.array(int(f_name.split("_")[0])-1).reshape(1, 1), pose_each_file.shape[0], axis=0)

            filtered_result = self._sample_from_clip(
                dst_lmdb_env,
                pose_each_file, trans_each_file, trans_v_each_file, shape_each_file, facial_each_file,
                vid_each_file,
                disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
                )
            for type in filtered_result.keys():
                n_filtered_out[type] += filtered_result[type]

        with dst_lmdb_env.begin() as txn:
            logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan"))
            n_total_filtered = 0
            for type, n_filtered in n_filtered_out.items():
                logger.info("{}: {}".format(type, n_filtered))
                n_total_filtered += n_filtered
            logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format(
                n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan"))
        dst_lmdb_env.sync()
        dst_lmdb_env.close()

    def _sample_from_clip(
        self, dst_lmdb_env, pose_each_file, trans_each_file, trans_v_each_file, shape_each_file, facial_each_file,
        vid_each_file,
        disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
        ):
        """
        for data cleaning, we ignore the data for the first and final n seconds
        for test, we return all data
        """

        round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps  # assume 1500 frames / 15 fps = 100 s
        #print(round_seconds_skeleton)

        clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s
        clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000, 90*160,000]
        clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150, 90*15]

        for ratio in self.args.multi_length_training:
            if is_test:  # stride = length for test
                cut_length = clip_e_f_pose - clip_s_f_pose
                self.args.stride = cut_length
                self.max_length = cut_length
            else:
                self.args.stride = int(ratio*self.ori_stride)
                cut_length = int(self.ori_length*ratio)

            num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1
            logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}")
            logger.info(f"{num_subdivision} clips are expected with stride {self.args.stride}")

            n_filtered_out = defaultdict(int)
            sample_pose_list = []
            sample_face_list = []
            sample_shape_list = []
            sample_vid_list = []
            sample_trans_list = []
            sample_trans_v_list = []

            for i in range(num_subdivision):  # cut into around 2s chips
                start_idx = clip_s_f_pose + i * self.args.stride
                fin_idx = start_idx + cut_length
                sample_pose = pose_each_file[start_idx:fin_idx]
                sample_trans = trans_each_file[start_idx:fin_idx]
                sample_trans_v = trans_v_each_file[start_idx:fin_idx]
                sample_shape = shape_each_file[start_idx:fin_idx]
                sample_face = facial_each_file[start_idx:fin_idx]
                # print(sample_pose.shape)
                sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1])

                if sample_pose is not None:
                    sample_pose_list.append(sample_pose)
                    sample_shape_list.append(sample_shape)
                    sample_vid_list.append(sample_vid)
                    sample_face_list.append(sample_face)
                    sample_trans_list.append(sample_trans)
                    sample_trans_v_list.append(sample_trans_v)

            if len(sample_pose_list) > 0:
                with dst_lmdb_env.begin(write=True) as txn:
                    for pose, shape, face, vid, trans, trans_v in zip(
                        sample_pose_list,
                        sample_shape_list,
                        sample_face_list,
                        sample_vid_list,
                        sample_trans_list,
                        sample_trans_v_list,
                        ):
                        k = "{:005}".format(self.n_out_samples).encode("ascii")
                        v = [pose, shape, face, vid, trans, trans_v]
                        v = pickle.dumps(v, 5)
                        txn.put(k, v)
                        self.n_out_samples += 1
        return n_filtered_out

    def __getitem__(self, idx):
        with self.lmdb_env.begin(write=False) as txn:
            key = "{:005}".format(idx).encode("ascii")
            sample = txn.get(key)
            sample = pickle.loads(sample)
            tar_pose, in_shape, tar_face, vid, trans, trans_v = sample
            tar_pose = torch.from_numpy(tar_pose).float()
            tar_face = torch.from_numpy(tar_face).float()
            # axis-angle -> rotation matrix -> 6D rotation representation
            tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(-1, 55, 3))
            tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(-1, 55*6)

            if self.norm:
                tar_pose = (tar_pose - self.mean) / self.std
                trans_v = (trans_v - self.trans_mean) / self.trans_std

            if self.loader_type == "test":
                tar_pose = tar_pose.float()
                trans = torch.from_numpy(trans).float()
                trans_v = torch.from_numpy(trans_v).float()
                vid = torch.from_numpy(vid).float()
                in_shape = torch.from_numpy(in_shape).float()
                tar_pose = torch.cat([tar_pose, trans_v], dim=1)
                tar_pose = torch.cat([tar_pose, tar_face], dim=1)
            else:
                in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float()
                trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float()
                trans_v = torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float()
                vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float()
                tar_pose = tar_pose.reshape((tar_pose.shape[0], -1)).float()
                tar_pose = torch.cat([tar_pose, trans_v], dim=1)
                tar_pose = torch.cat([tar_pose, tar_face], dim=1)
        return tar_pose
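
A sketch of how this dataset class is typically consumed (illustrative: `args` must supply the fields the constructor reads, such as data_path, pose_rep, stride, pose_length and training_speakers; in this repo they come from the configs/*.yaml files):

    from torch.utils.data import DataLoader

    train_set = CustomDataset(args, loader_type="train")
    loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)
    for tar_pose in loader:
        # each batch concatenates the normalized 330-d 6D rotations, the 3-d
        # translation velocity, and the facial expression coefficients
        pass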
dataloaders/pymo/Quaternions.py
ADDED
@@ -0,0 +1,468 @@
import numpy as np

class Quaternions:
    """
    Quaternions is a wrapper around a numpy ndarray
    that allows it to act as if it were an ndarray of
    a quaternion data type.

    Therefore addition, subtraction, multiplication,
    division, negation, absolute, are all defined
    in terms of quaternion operations such as quaternion
    multiplication.

    This allows for much neater code and many routines
    which conceptually do the same thing to be written
    in the same way for point data and for rotation data.

    The Quaternions class has been designed such that it
    should support broadcasting and slicing in all of the
    usual ways.
    """

    def __init__(self, qs):
        if isinstance(qs, np.ndarray):
            if len(qs.shape) == 1: qs = np.array([qs])
            self.qs = qs
            return

        if isinstance(qs, Quaternions):
            self.qs = qs.qs
            return

        raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs))

    def __str__(self): return "Quaternions(" + str(self.qs) + ")"
    def __repr__(self): return "Quaternions(" + repr(self.qs) + ")"

    """ Helper Methods for Broadcasting and Data extraction """

    @classmethod
    def _broadcast(cls, sqs, oqs, scalar=False):
        if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1])

        ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1])
        os = np.array(oqs.shape)

        if len(ss) != len(os):
            raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))

        if np.all(ss == os): return sqs, oqs

        if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))):
            raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))

        sqsn, oqsn = sqs.copy(), oqs.copy()

        for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a)
        for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a)

        return sqsn, oqsn

    """ Adding Quaternions is just Defined as Multiplication """

    def __add__(self, other): return self * other
    def __sub__(self, other): return self / other

    """ Quaternion Multiplication """

    def __mul__(self, other):
        """
        Quaternion multiplication has three main methods.

        When multiplying a Quaternions array by Quaternions
        normal quaternion multiplication is performed.

        When multiplying a Quaternions array by a vector
        array of the same shape, where the last axis is 3,
        it is assumed to be a Quaternion by 3D-Vector
        multiplication and the 3D-Vectors are rotated
        in space by the Quaternions.

        When multiplying a Quaternions array by a scalar
        or vector of different shape it is assumed to be
        a Quaternions by Scalars multiplication and the
        Quaternions are scaled using Slerp and the identity
        quaternions.
        """

        """ If Quaternions type do Quaternions * Quaternions """
        if isinstance(other, Quaternions):
            sqs, oqs = Quaternions._broadcast(self.qs, other.qs)

            q0 = sqs[...,0]; q1 = sqs[...,1]
            q2 = sqs[...,2]; q3 = sqs[...,3]
            r0 = oqs[...,0]; r1 = oqs[...,1]
            r2 = oqs[...,2]; r3 = oqs[...,3]

            qs = np.empty(sqs.shape)
            qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3
            qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2
            qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1
            qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0

            return Quaternions(qs)

        """ If array type do Quaternions * Vectors """
        if isinstance(other, np.ndarray) and other.shape[-1] == 3:
            vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1))
            return (self * (vs * -self)).imaginaries

        """ If float do Quaternions * Scalars """
        if isinstance(other, np.ndarray) or isinstance(other, float):
            return Quaternions.slerp(Quaternions.id_like(self), self, other)

        raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other)))

    def __div__(self, other):
        """
        When a Quaternion type is supplied, division is defined
        as multiplication by the inverse of that Quaternion.

        When a scalar or vector is supplied it is defined
        as multiplication by one over the supplied value.
        Essentially a scaling.
        """

        if isinstance(other, Quaternions): return self * (-other)
        if isinstance(other, np.ndarray): return self * (1.0 / other)
        if isinstance(other, float): return self * (1.0 / other)
        raise TypeError('Cannot divide/subtract Quaternions with type %s' % str(type(other)))

    __truediv__ = __div__  # Python 3 name for the division operator

    def __eq__(self, other): return self.qs == other.qs
    def __ne__(self, other): return self.qs != other.qs

    def __neg__(self):
        """ Invert Quaternions """
        return Quaternions(self.qs * np.array([[1, -1, -1, -1]]))

    def __abs__(self):
        """ Unify Quaternions To Single Pole """
        qabs = self.normalized().copy()
        top = np.sum(( qabs.qs) * np.array([1,0,0,0]), axis=-1)
        bot = np.sum((-qabs.qs) * np.array([1,0,0,0]), axis=-1)
        qabs.qs[top < bot] = -qabs.qs[top < bot]
        return qabs

    def __iter__(self): return iter(self.qs)
    def __len__(self): return len(self.qs)

    def __getitem__(self, k): return Quaternions(self.qs[k])
    def __setitem__(self, k, v): self.qs[k] = v.qs

    @property
    def lengths(self):
        return np.sum(self.qs**2.0, axis=-1)**0.5

    @property
    def reals(self):
        return self.qs[...,0]

    @property
    def imaginaries(self):
        return self.qs[...,1:4]

    @property
    def shape(self): return self.qs.shape[:-1]

    def repeat(self, n, **kwargs):
        return Quaternions(self.qs.repeat(n, **kwargs))

    def normalized(self):
        return Quaternions(self.qs / self.lengths[...,np.newaxis])

    def log(self):
        norm = abs(self.normalized())
        imgs = norm.imaginaries
        lens = np.sqrt(np.sum(imgs**2, axis=-1))
        lens = np.arctan2(lens, norm.reals) / (lens + 1e-10)
        return imgs * lens[...,np.newaxis]

    def constrained(self, axis):
        rl = self.reals
        im = np.sum(axis * self.imaginaries, axis=-1)

        t1 = -2 * np.arctan2(rl, im) + np.pi
        t2 = -2 * np.arctan2(rl, im) - np.pi

        top = Quaternions.exp(axis[np.newaxis] * (t1[:,np.newaxis] / 2.0))
        bot = Quaternions.exp(axis[np.newaxis] * (t2[:,np.newaxis] / 2.0))
        img = self.dot(top) > self.dot(bot)

        ret = top.copy()
        ret[ img] = top[ img]
        ret[~img] = bot[~img]
        return ret

    def constrained_x(self): return self.constrained(np.array([1,0,0]))
    def constrained_y(self): return self.constrained(np.array([0,1,0]))
    def constrained_z(self): return self.constrained(np.array([0,0,1]))

    def dot(self, q): return np.sum(self.qs * q.qs, axis=-1)

    def copy(self): return Quaternions(np.copy(self.qs))

    def reshape(self, s):
        self.qs = self.qs.reshape(s)
        return self

    def interpolate(self, ws):
        return Quaternions.exp(np.average(abs(self).log(), axis=0, weights=ws))

    def euler(self, order='xyz'):
        q = self.normalized().qs
        q0 = q[...,0]
        q1 = q[...,1]
        q2 = q[...,2]
        q3 = q[...,3]
        es = np.zeros(self.shape + (3,))

        if order == 'xyz':
            es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
            es[...,1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1,1))
            es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
        elif order == 'yzx':
            es[...,0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0)
            es[...,1] = np.arctan2(2 * (q2 * q0 - q1 * q3),  q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0)
            es[...,2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1,1))
        else:
            raise NotImplementedError('Cannot convert from ordering %s' % order)

        """

        # These conversions don't appear to work correctly for Maya.
        # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/

        if order == 'xyz':
            es[...,0] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
            es[...,1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1))
            es[...,2] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
        elif order == 'yzx':
            es[...,0] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
            es[...,1] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1))
            es[...,2] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
        elif order == 'zxy':
            es[...,0] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
            es[...,1] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1))
            es[...,2] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
        elif order == 'xzy':
            es[...,0] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
            es[...,1] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1))
            es[...,2] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
        elif order == 'yxz':
            es[...,0] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
            es[...,1] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1))
            es[...,2] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
        elif order == 'zyx':
            es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
            es[...,1] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1))
            es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
        else:
            raise KeyError('Unknown ordering %s' % order)

        """

        # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp
        # Use this class and convert from matrix

        return es

    def average(self):
        if len(self.shape) == 1:
            import numpy.core.umath_tests as ut
            system = ut.matrix_multiply(self.qs[:,:,np.newaxis], self.qs[:,np.newaxis,:]).sum(axis=0)
            w, v = np.linalg.eigh(system)
            qiT_dot_qref = (self.qs[:,:,np.newaxis] * v[np.newaxis,:,:]).sum(axis=1)
            return Quaternions(v[:,np.argmin((1.-qiT_dot_qref**2).sum(axis=0))])
        else:
            raise NotImplementedError('Cannot average multi-dimensional Quaternions')

    def angle_axis(self):
        norm = self.normalized()
        s = np.sqrt(1 - (norm.reals**2.0))
        s[s == 0] = 0.001

        angles = 2.0 * np.arccos(norm.reals)
        axis = norm.imaginaries / s[...,np.newaxis]

        return angles, axis

    def transforms(self):
        qw = self.qs[...,0]
        qx = self.qs[...,1]
        qy = self.qs[...,2]
        qz = self.qs[...,3]

        x2 = qx + qx; y2 = qy + qy; z2 = qz + qz
        xx = qx * x2; yy = qy * y2; wx = qw * x2
        xy = qx * y2; yz = qy * z2; wy = qw * y2
        xz = qx * z2; zz = qz * z2; wz = qw * z2

        m = np.empty(self.shape + (3,3))
        m[...,0,0] = 1.0 - (yy + zz)
        m[...,0,1] = xy - wz
        m[...,0,2] = xz + wy
        m[...,1,0] = xy + wz
        m[...,1,1] = 1.0 - (xx + zz)
        m[...,1,2] = yz - wx
        m[...,2,0] = xz - wy
        m[...,2,1] = yz + wx
        m[...,2,2] = 1.0 - (xx + yy)

        return m

    def ravel(self):
        return self.qs.ravel()

    @classmethod
    def id(cls, n):
        if isinstance(n, tuple):
            qs = np.zeros(n + (4,))
            qs[...,0] = 1.0
            return Quaternions(qs)

        if isinstance(n, int):  # Python 2's `long` no longer exists
            qs = np.zeros((n,4))
            qs[:,0] = 1.0
            return Quaternions(qs)

        raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n)))

    @classmethod
    def id_like(cls, a):
        qs = np.zeros(a.shape + (4,))
        qs[...,0] = 1.0
        return Quaternions(qs)

    @classmethod
    def exp(cls, ws):
        ts = np.sum(ws**2.0, axis=-1)**0.5
        ts[ts == 0] = 0.001
        ls = np.sin(ts) / ts

        qs = np.empty(ws.shape[:-1] + (4,))
        qs[...,0] = np.cos(ts)
        qs[...,1] = ws[...,0] * ls
        qs[...,2] = ws[...,1] * ls
        qs[...,3] = ws[...,2] * ls

        return Quaternions(qs).normalized()

    @classmethod
    def slerp(cls, q0s, q1s, a):
        fst, snd = cls._broadcast(q0s.qs, q1s.qs)
        fst, a = cls._broadcast(fst, a, scalar=True)
        snd, a = cls._broadcast(snd, a, scalar=True)

        len = np.sum(fst * snd, axis=-1)

        neg = len < 0.0
        len[neg] = -len[neg]
        snd[neg] = -snd[neg]

        amount0 = np.zeros(a.shape)
        amount1 = np.zeros(a.shape)

        linear = (1.0 - len) < 0.01
        omegas = np.arccos(len[~linear])
        sinoms = np.sin(omegas)

        amount0[ linear] = 1.0 - a[linear]
        amount1[ linear] = a[linear]
        amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms
        amount1[~linear] = np.sin( a[~linear] * omegas) / sinoms

        return Quaternions(
            amount0[...,np.newaxis] * fst +
            amount1[...,np.newaxis] * snd)

    @classmethod
    def between(cls, v0s, v1s):
        a = np.cross(v0s, v1s)
        w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1)
        return Quaternions(np.concatenate([w[...,np.newaxis], a], axis=-1)).normalized()

    @classmethod
    def from_angle_axis(cls, angles, axis):
        axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[...,np.newaxis]
        sines = np.sin(angles / 2.0)[...,np.newaxis]
        cosines = np.cos(angles / 2.0)[...,np.newaxis]
        return Quaternions(np.concatenate([cosines, axis * sines], axis=-1))

    @classmethod
    def from_euler(cls, es, order='xyz', world=False):
        axis = {
            'x' : np.array([1,0,0]),
            'y' : np.array([0,1,0]),
            'z' : np.array([0,0,1]),
        }

        q0s = Quaternions.from_angle_axis(es[...,0], axis[order[0]])
        q1s = Quaternions.from_angle_axis(es[...,1], axis[order[1]])
        q2s = Quaternions.from_angle_axis(es[...,2], axis[order[2]])

        return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s))

    @classmethod
    def from_transforms(cls, ts):
        d0, d1, d2 = ts[...,0,0], ts[...,1,1], ts[...,2,2]

        q0 = ( d0 + d1 + d2 + 1.0) / 4.0
        q1 = ( d0 - d1 - d2 + 1.0) / 4.0
        q2 = (-d0 + d1 - d2 + 1.0) / 4.0
        q3 = (-d0 - d1 + d2 + 1.0) / 4.0

        q0 = np.sqrt(q0.clip(0,None))
        q1 = np.sqrt(q1.clip(0,None))
        q2 = np.sqrt(q2.clip(0,None))
        q3 = np.sqrt(q3.clip(0,None))

        c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3)
        c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3)
        c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3)
        c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2)

        q1[c0] *= np.sign(ts[c0,2,1] - ts[c0,1,2])
        q2[c0] *= np.sign(ts[c0,0,2] - ts[c0,2,0])
        q3[c0] *= np.sign(ts[c0,1,0] - ts[c0,0,1])

        q0[c1] *= np.sign(ts[c1,2,1] - ts[c1,1,2])
        q2[c1] *= np.sign(ts[c1,1,0] + ts[c1,0,1])
        q3[c1] *= np.sign(ts[c1,0,2] + ts[c1,2,0])

        q0[c2] *= np.sign(ts[c2,0,2] - ts[c2,2,0])
        q1[c2] *= np.sign(ts[c2,1,0] + ts[c2,0,1])
        q3[c2] *= np.sign(ts[c2,2,1] + ts[c2,1,2])

        q0[c3] *= np.sign(ts[c3,1,0] - ts[c3,0,1])
        q1[c3] *= np.sign(ts[c3,2,0] + ts[c3,0,2])
        q2[c3] *= np.sign(ts[c3,2,1] + ts[c3,1,2])

        qs = np.empty(ts.shape[:-2] + (4,))
        qs[...,0] = q0
        qs[...,1] = q1
        qs[...,2] = q2
        qs[...,3] = q3

        return cls(qs)
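
A small self-contained check of the class above (illustrative): build two rotations from Euler angles, slerp halfway between them, and read the result back as Euler angles.

    import numpy as np

    qa = Quaternions.from_euler(np.radians(np.array([[90.0, 0.0, 0.0]])), order='xyz')
    qb = Quaternions.from_euler(np.radians(np.array([[ 0.0, 0.0, 0.0]])), order='xyz')
    half = Quaternions.slerp(qa, qb, np.array([0.5]))
    print(np.degrees(half.euler(order='xyz')))  # approximately [[45., 0., 0.]]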
dataloaders/pymo/__init__.py
ADDED
File without changes
dataloaders/pymo/__pycache__/Quaternions.cpython-312.pyc
ADDED
Binary file (28.3 kB)
dataloaders/pymo/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (173 Bytes)
dataloaders/pymo/__pycache__/data.cpython-312.pyc
ADDED
Binary file (3.34 kB)
dataloaders/pymo/__pycache__/parsers.cpython-312.pyc
ADDED
Binary file (13.1 kB)
dataloaders/pymo/__pycache__/preprocessing.cpython-312.pyc
ADDED
Binary file (33.8 kB)
dataloaders/pymo/__pycache__/rotation_tools.cpython-312.pyc
ADDED
Binary file (8.27 kB)
dataloaders/pymo/__pycache__/viz_tools.cpython-312.pyc
ADDED
Binary file (10.9 kB)
dataloaders/pymo/data.py
ADDED
@@ -0,0 +1,53 @@
import numpy as np

class Joint():
    def __init__(self, name, parent=None, children=None):
        self.name = name
        self.parent = parent
        self.children = children

class MocapData():
    def __init__(self):
        self.skeleton = {}
        self.values = None
        self.channel_names = []
        self.framerate = 0.0
        self.root_name = ''

    def traverse(self, j=None):
        stack = [self.root_name]
        while stack:
            joint = stack.pop()
            yield joint
            for c in self.skeleton[joint]['children']:
                stack.append(c)

    def clone(self):
        import copy
        new_data = MocapData()
        new_data.skeleton = copy.copy(self.skeleton)
        new_data.values = copy.copy(self.values)
        new_data.channel_names = copy.copy(self.channel_names)
        new_data.root_name = copy.copy(self.root_name)
        new_data.framerate = copy.copy(self.framerate)
        return new_data

    def get_all_channels(self):
        '''Returns all of the channels parsed from the file as a 2D numpy array'''
        frames = [f[1] for f in self.values]
        return np.asarray([[channel[2] for channel in frame] for frame in frames])

    def get_skeleton_tree(self):
        tree = []
        root_key = [j for j in self.skeleton if self.skeleton[j]['parent']==None][0]
        root_joint = Joint(root_key)
        # NOTE: linking child joints into the tree is left unfinished here (TODO)
        return root_joint

    def get_empty_channels(self):
        #TODO
        pass

    def get_constant_channels(self):
        #TODO
        pass
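
Assuming `mocap` is the MocapData instance returned by BVHParser.parse (defined in parsers.py below), traverse() walks joint names from the root:

    for joint in mocap.traverse():
        print(joint, '->', mocap.skeleton[joint]['children'])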
dataloaders/pymo/features.py
ADDED
@@ -0,0 +1,43 @@
'''
A set of mocap feature extraction functions

Created by Omid Alemi | Nov 17 2017

'''
import numpy as np
import pandas as pd
import peakutils
import matplotlib.pyplot as plt

def get_foot_contact_idxs(signal, t=0.02, min_dist=120):
    up_idxs = peakutils.indexes(signal, thres=t/max(signal), min_dist=min_dist)
    down_idxs = peakutils.indexes(-signal, thres=t/min(signal), min_dist=min_dist)

    return [up_idxs, down_idxs]


def create_foot_contact_signal(mocap_track, col_name, start=1, t=0.02, min_dist=120):
    signal = mocap_track.values[col_name].values
    idxs = get_foot_contact_idxs(signal, t, min_dist)

    step_signal = []

    c = start
    for f in range(len(signal)):
        if f in idxs[1]:
            c = 0
        elif f in idxs[0]:
            c = 1

        step_signal.append(c)

    return step_signal

def plot_foot_up_down(mocap_track, col_name, t=0.02, min_dist=120):
    signal = mocap_track.values[col_name].values
    idxs = get_foot_contact_idxs(signal, t, min_dist)

    plt.plot(mocap_track.values.index, signal)
    plt.plot(mocap_track.values.index[idxs[0]], signal[idxs[0]], 'ro')
    plt.plot(mocap_track.values.index[idxs[1]], signal[idxs[1]], 'go')
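
An illustrative call (peakutils must be installed, and the column name depends on the parsed skeleton, so 'RightFoot_Yposition' is only an example):

    contact = create_foot_contact_signal(mocap, 'RightFoot_Yposition')
    plot_foot_up_down(mocap, 'RightFoot_Yposition')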
dataloaders/pymo/parsers.py
ADDED
@@ -0,0 +1,274 @@
| 1 |
+
'''
|
| 2 |
+
BVH Parser Class
|
| 3 |
+
|
| 4 |
+
By Omid Alemi
|
| 5 |
+
Created: June 12, 2017
|
| 6 |
+
|
| 7 |
+
Based on: https://gist.github.com/johnfredcee/2007503
|
| 8 |
+
|
| 9 |
+
'''
|
| 10 |
+
import re
|
| 11 |
+
from unicodedata import name
|
| 12 |
+
import numpy as np
|
| 13 |
+
from .data import Joint, MocapData
|
| 14 |
+
|
| 15 |
+
class BVHScanner():
|
| 16 |
+
'''
|
| 17 |
+
A wrapper class for re.Scanner
|
| 18 |
+
'''
|
| 19 |
+
def __init__(self):
|
| 20 |
+
|
| 21 |
+
def identifier(scanner, token):
|
| 22 |
+
return 'IDENT', token
|
| 23 |
+
|
| 24 |
+
def operator(scanner, token):
|
| 25 |
+
return 'OPERATOR', token
|
| 26 |
+
|
| 27 |
+
def digit(scanner, token):
|
| 28 |
+
return 'DIGIT', token
|
| 29 |
+
|
| 30 |
+
def open_brace(scanner, token):
|
| 31 |
+
return 'OPEN_BRACE', token
|
| 32 |
+
|
| 33 |
+
def close_brace(scanner, token):
|
| 34 |
+
return 'CLOSE_BRACE', token
|
| 35 |
+
|
| 36 |
+
self.scanner = re.Scanner([
|
| 37 |
+
(r'[a-zA-Z_]\w*', identifier),
|
| 38 |
+
#(r'-*[0-9]+(\.[0-9]+)?', digit), # won't work for .34
|
| 39 |
+
#(r'[-+]?[0-9]*\.?[0-9]+', digit), # won't work for 4.56e-2
|
| 40 |
+
#(r'[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit),
|
| 41 |
+
(r'-*[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit),
|
| 42 |
+
(r'}', close_brace),
|
| 43 |
+
(r'}', close_brace),
|
| 44 |
+
(r'{', open_brace),
|
| 45 |
+
(r':', None),
|
| 46 |
+
(r'\s+', None)
|
| 47 |
+
])
|
| 48 |
+
|
| 49 |
+
def scan(self, stuff):
|
| 50 |
+
return self.scanner.scan(stuff)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class BVHParser():
|
| 55 |
+
'''
|
| 56 |
+
A class to parse a BVH file.
|
| 57 |
+
|
| 58 |
+
Extracts the skeleton and channel values
|
| 59 |
+
'''
|
| 60 |
+
def __init__(self, filename=None):
|
| 61 |
+
self.reset()
|
| 62 |
+
|
| 63 |
+
def reset(self):
|
| 64 |
+
self._skeleton = {}
|
| 65 |
+
self.bone_context = []
|
| 66 |
+
self._motion_channels = []
|
| 67 |
+
self._motions = []
|
| 68 |
+
self.current_token = 0
|
| 69 |
+
self.framerate = 0.0
|
| 70 |
+
self.root_name = ''
|
| 71 |
+
|
| 72 |
+
self.scanner = BVHScanner()
|
| 73 |
+
|
| 74 |
+
self.data = MocapData()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def parse(self, filename, start=0, stop=-1):
|
| 78 |
+
self.reset()
|
| 79 |
+
self.correct_row_num = 0
|
| 80 |
+
with open(filename, 'r') as f:
|
| 81 |
+
for line in f.readlines():
|
| 82 |
+
self.correct_row_num += 1
|
| 83 |
+
|
| 84 |
+
with open(filename, 'r') as bvh_file:
|
| 85 |
+
raw_contents = bvh_file.read()
|
| 86 |
+
tokens, remainder = self.scanner.scan(raw_contents)
|
| 87 |
+
|
| 88 |
+
self._parse_hierarchy(tokens)
|
| 89 |
+
self.current_token = self.current_token + 1
|
| 90 |
+
self._parse_motion(tokens, start, stop)
|
| 91 |
+
|
| 92 |
+
self.data.skeleton = self._skeleton
|
| 93 |
+
self.data.channel_names = self._motion_channels
|
| 94 |
+
self.data.values = self._to_DataFrame()
|
| 95 |
+
self.data.root_name = self.root_name
|
| 96 |
+
self.data.framerate = self.framerate
|
| 97 |
+
|
| 98 |
+
return self.data
|
| 99 |
+
|
| 100 |
+
def _to_DataFrame(self):
|
| 101 |
+
'''Returns all of the channels parsed from the file as a pandas DataFrame'''
|
| 102 |
+
|
| 103 |
+
import pandas as pd
|
| 104 |
+
time_index = pd.to_timedelta([f[0] for f in self._motions], unit='s')
|
| 105 |
+
frames = [f[1] for f in self._motions]
|
| 106 |
+
channels = np.asarray([[channel[2] for channel in frame] for frame in frames])
|
| 107 |
+
column_names = ['%s_%s'%(c[0], c[1]) for c in self._motion_channels]
|
| 108 |
+
|
| 109 |
+
return pd.DataFrame(data=channels, index=time_index, columns=column_names)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _new_bone(self, parent, name):
|
| 113 |
+
bone = {'parent': parent, 'channels': [], 'offsets': [], 'order': '','children': []}
|
| 114 |
+
return bone
|
| 115 |
+
|
| 116 |
+
def _push_bone_context(self,name):
|
| 117 |
+
self.bone_context.append(name)
|
| 118 |
+
|
| 119 |
+
def _get_bone_context(self):
|
| 120 |
+
return self.bone_context[len(self.bone_context)-1]
|
| 121 |
+
|
| 122 |
+
def _pop_bone_context(self):
|
| 123 |
+
self.bone_context = self.bone_context[:-1]
|
| 124 |
+
return self.bone_context[len(self.bone_context)-1]
|
| 125 |
+
|
| 126 |
+
def _read_offset(self, bvh, token_index):
|
| 127 |
+
if bvh[token_index] != ('IDENT', 'OFFSET'):
|
| 128 |
+
return None, None
|
| 129 |
+
token_index = token_index + 1
|
| 130 |
+
offsets = [0.0] * 3
|
| 131 |
+
for i in range(3):
|
| 132 |
+
offsets[i] = float(bvh[token_index][1])
|
| 133 |
+
token_index = token_index + 1
|
| 134 |
+
return offsets, token_index
|
| 135 |
+
|
| 136 |
+
def _read_channels(self, bvh, token_index):
|
| 137 |
+
if bvh[token_index] != ('IDENT', 'CHANNELS'):
|
| 138 |
+
return None, None
|
| 139 |
+
token_index = token_index + 1
|
| 140 |
+
channel_count = int(bvh[token_index][1])
|
| 141 |
+
token_index = token_index + 1
|
| 142 |
+
channels = [""] * channel_count
|
| 143 |
+
order = ""
|
| 144 |
+
for i in range(channel_count):
|
| 145 |
+
channels[i] = bvh[token_index][1]
|
| 146 |
+
token_index = token_index + 1
|
| 147 |
+
            if channels[i] == "Xrotation" or channels[i] == "Yrotation" or channels[i] == "Zrotation":
                order += channels[i][0]
            else:
                order = ""
        return channels, token_index, order

    def _parse_joint(self, bvh, token_index):
        end_site = False
        joint_id = bvh[token_index][1]
        token_index = token_index + 1
        joint_name = bvh[token_index][1]
        token_index = token_index + 1

        parent_name = self._get_bone_context()

        if joint_id == "End":
            joint_name = parent_name + '_Nub'
            end_site = True
        joint = self._new_bone(parent_name, joint_name)
        if bvh[token_index][0] != 'OPEN_BRACE':
            print('Was expecting brace, got ', bvh[token_index])
            return None
        token_index = token_index + 1
        offsets, token_index = self._read_offset(bvh, token_index)
        joint['offsets'] = offsets
        if not end_site:
            channels, token_index, order = self._read_channels(bvh, token_index)
            joint['channels'] = channels
            joint['order'] = order
            for channel in channels:
                self._motion_channels.append((joint_name, channel))

        self._skeleton[joint_name] = joint
        self._skeleton[parent_name]['children'].append(joint_name)

        while (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'JOINT') or (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'End'):
            self._push_bone_context(joint_name)
            token_index = self._parse_joint(bvh, token_index)
            self._pop_bone_context()

        if bvh[token_index][0] == 'CLOSE_BRACE':
            return token_index + 1

        print('Unexpected token ', bvh[token_index])

    def _parse_hierarchy(self, bvh):
        self.current_token = 0
        if bvh[self.current_token] != ('IDENT', 'HIERARCHY'):
            return None
        self.current_token = self.current_token + 1
        if bvh[self.current_token] != ('IDENT', 'ROOT'):
            return None
        self.current_token = self.current_token + 1
        if bvh[self.current_token][0] != 'IDENT':
            return None

        root_name = bvh[self.current_token][1]
        root_bone = self._new_bone(None, root_name)
        self.current_token = self.current_token + 2  # skipping open brace
        offsets, self.current_token = self._read_offset(bvh, self.current_token)
        channels, self.current_token, order = self._read_channels(bvh, self.current_token)
        root_bone['offsets'] = offsets
        root_bone['channels'] = channels
        root_bone['order'] = order
        self._skeleton[root_name] = root_bone
        self._push_bone_context(root_name)

        for channel in channels:
            self._motion_channels.append((root_name, channel))

        while bvh[self.current_token][1] == 'JOINT':
            self.current_token = self._parse_joint(bvh, self.current_token)

        self.root_name = root_name

    def _parse_motion(self, bvh, start, stop):
        if bvh[self.current_token][0] != 'IDENT':
            print('Unexpected text')
            return None
        if bvh[self.current_token][1] != 'MOTION':
            print('No motion section')
            return None
        self.current_token = self.current_token + 1
        if bvh[self.current_token][1] != 'Frames':
            return None
        self.current_token = self.current_token + 1
        frame_count = int(bvh[self.current_token][1])

        if stop < 0 or stop > frame_count:
            stop = min(frame_count, self.correct_row_num - 431)

        assert start >= 0
        assert start < stop

        self.current_token = self.current_token + 1
        if bvh[self.current_token][1] != 'Frame':
            return None
        self.current_token = self.current_token + 1
        if bvh[self.current_token][1] != 'Time':
            return None
        self.current_token = self.current_token + 1
        frame_rate = float(bvh[self.current_token][1])

        self.framerate = frame_rate

        self.current_token = self.current_token + 1

        frame_time = 0.0
        self._motions = [()] * (stop - start)
        idx = 0
        for i in range(stop):
            channel_values = []

            for channel in self._motion_channels:
                channel_values.append((channel[0], channel[1], float(bvh[self.current_token][1])))
                self.current_token = self.current_token + 1

            if i >= start:
                self._motions[idx] = (frame_time, channel_values)
                frame_time = frame_time + frame_rate
                idx += 1


if __name__ == "__main__":
    p = BVHParser()
    data = [p.parse("../../../datasets/beat_full/2/2_scott_0_1_1.bvh")]
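A minimal usage sketch (illustrative, not part of the upload; the import path and the .bvh path are placeholders based on the repo layout): parse one BEAT file and inspect the fields that the pymo transformers below read from the parsed track.

from dataloaders.pymo.parsers import BVHParser

parser = BVHParser()
mocap = parser.parse("datasets/beat_full/2/2_scott_0_1_1.bvh")
print(mocap.root_name)      # name of the ROOT joint
print(mocap.framerate)      # seconds per frame, read from the 'Frame Time:' line
print(len(mocap.skeleton))  # joints, including the generated '_Nub' end sites
print(mocap.values.shape)   # (n_frames, n_channels) DataFrame of parsed channels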
dataloaders/pymo/preprocessing.py
ADDED
@@ -0,0 +1,726 @@
'''
Preprocessing Transformers Based on sci-kit's API

By Omid Alemi
Created on June 12, 2017
'''
import copy
import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from .Quaternions import Quaternions
from .rotation_tools import Rotation

class MocapParameterizer(BaseEstimator, TransformerMixin):
    def __init__(self, param_type='euler'):
        '''
        param_type = {'euler', 'quat', 'expmap', 'position'}
        '''
        self.param_type = param_type

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        if self.param_type == 'euler':
            return X
        elif self.param_type == 'expmap':
            return self._to_expmap(X)
        elif self.param_type == 'quat':
            return X
        elif self.param_type == 'position':
            return self._to_pos(X)
        else:
            raise UnsupportedParamError('Unsupported param: %s. Valid param types are: euler, quat, expmap, position' % self.param_type)

    def inverse_transform(self, X, copy=None):
        if self.param_type == 'euler':
            return X
        elif self.param_type == 'expmap':
            return self._expmap_to_euler(X)
        elif self.param_type == 'quat':
            raise UnsupportedParamError('quat2euler is not supported')
        elif self.param_type == 'position':
            print('positions 2 eulers is not supported')
            return X
        else:
            raise UnsupportedParamError('Unsupported param: %s. Valid param types are: euler, quat, expmap, position' % self.param_type)

    def _to_pos(self, X):
        '''Converts joints rotations in Euler angles to joint positions'''
        Q = []
        for track in X:
            channels = []
            titles = []
            euler_df = track.values

            # Create a new DataFrame to store the positions
            pos_df = pd.DataFrame(index=euler_df.index)

            # List the columns that contain rotation channels
            rot_cols = [c for c in euler_df.columns if 'rotation' in c]

            # List the columns that contain position channels
            pos_cols = [c for c in euler_df.columns if 'position' in c]

            tree_data = {}

            for joint in track.traverse():
                parent = track.skeleton[joint]['parent']
                rot_order = track.skeleton[joint]['order']

                # Get the rotation columns that belong to this joint
                rc = euler_df[[c for c in rot_cols if joint in c]]

                # Get the position columns that belong to this joint
                pc = euler_df[[c for c in pos_cols if joint in c]]

                # Make sure the columns are organized in xyz order
                if rc.shape[1] < 3:
                    euler_values = np.zeros((euler_df.shape[0], 3))
                    rot_order = "XYZ"
                else:
                    euler_values = np.pi/180.0*np.transpose(np.array([track.values['%s_%srotation'%(joint, rot_order[0])], track.values['%s_%srotation'%(joint, rot_order[1])], track.values['%s_%srotation'%(joint, rot_order[2])]]))

                if pc.shape[1] < 3:
                    pos_values = np.asarray([[0, 0, 0] for f in pc.iterrows()])
                else:
                    pos_values = np.asarray([[f[1]['%s_Xposition'%joint],
                                              f[1]['%s_Yposition'%joint],
                                              f[1]['%s_Zposition'%joint]] for f in pc.iterrows()])

                quats = Quaternions.from_euler(np.asarray(euler_values), order=rot_order.lower(), world=False)

                tree_data[joint] = [
                    [],  # to store the rotation matrix
                    []   # to store the calculated position
                ]
                if track.root_name == joint:
                    tree_data[joint][0] = quats
                    tree_data[joint][1] = pos_values
                else:
                    # for every frame i, multiply this joint's rotmat to the rotmat of its parent
                    tree_data[joint][0] = tree_data[parent][0] * quats

                    # add the position channel to the offset and store it in k, for every frame i
                    k = pos_values + np.asarray(track.skeleton[joint]['offsets'])

                    # multiply k to the rotmat of the parent for every frame i
                    q = tree_data[parent][0] * k

                    # add q to the position of the parent, for every frame i
                    tree_data[joint][1] = tree_data[parent][1] + q

                # Create the corresponding columns in the new DataFrame
                pos_df['%s_Xposition'%joint] = pd.Series(data=[e[0] for e in tree_data[joint][1]], index=pos_df.index)
                pos_df['%s_Yposition'%joint] = pd.Series(data=[e[1] for e in tree_data[joint][1]], index=pos_df.index)
                pos_df['%s_Zposition'%joint] = pd.Series(data=[e[2] for e in tree_data[joint][1]], index=pos_df.index)

            new_track = track.clone()
            new_track.values = pos_df
            Q.append(new_track)
        return Q

    def _to_expmap(self, X):
        '''Converts Euler angles to Exponential Maps'''
        Q = []
        for track in X:
            channels = []
            titles = []
            euler_df = track.values

            # Create a new DataFrame to store the exponential map rep
            exp_df = pd.DataFrame(index=euler_df.index)

            # Copy the root positions into the new DataFrame
            rxp = '%s_Xposition'%track.root_name
            ryp = '%s_Yposition'%track.root_name
            rzp = '%s_Zposition'%track.root_name
            exp_df[rxp] = pd.Series(data=euler_df[rxp], index=exp_df.index)
            exp_df[ryp] = pd.Series(data=euler_df[ryp], index=exp_df.index)
            exp_df[rzp] = pd.Series(data=euler_df[rzp], index=exp_df.index)

            # List the columns that contain rotation channels
            rots = [c for c in euler_df.columns if ('rotation' in c and 'Nub' not in c)]

            # List the joints that are not end sites, i.e., have channels
            joints = (joint for joint in track.skeleton if 'Nub' not in joint)

            for joint in joints:
                r = euler_df[[c for c in rots if joint in c]]  # Get the columns that belong to this joint
                euler = [[f[1]['%s_Xrotation'%joint], f[1]['%s_Yrotation'%joint], f[1]['%s_Zrotation'%joint]] for f in r.iterrows()]  # Make sure the columns are organized in xyz order
                # Convert the eulers to exp maps (Rotation.__init__ requires the joint's rotation order)
                exps = [Rotation(f, 'euler', track.skeleton[joint]['order'], from_deg=True).to_expmap() for f in euler]

                # Create the corresponding columns in the new DataFrame
                exp_df['%s_alpha'%joint] = pd.Series(data=[e[0] for e in exps], index=exp_df.index)
                exp_df['%s_beta'%joint] = pd.Series(data=[e[1] for e in exps], index=exp_df.index)
                exp_df['%s_gamma'%joint] = pd.Series(data=[e[2] for e in exps], index=exp_df.index)

            new_track = track.clone()
            new_track.values = exp_df
            Q.append(new_track)

        return Q

    def _expmap_to_euler(self, X):
        Q = []
        for track in X:
            channels = []
            titles = []
            exp_df = track.values

            # Create a new DataFrame to store the euler rep
            euler_df = pd.DataFrame(index=exp_df.index)

            # Copy the root positions into the new DataFrame
            rxp = '%s_Xposition'%track.root_name
            ryp = '%s_Yposition'%track.root_name
            rzp = '%s_Zposition'%track.root_name
            euler_df[rxp] = pd.Series(data=exp_df[rxp], index=euler_df.index)
            euler_df[ryp] = pd.Series(data=exp_df[ryp], index=euler_df.index)
            euler_df[rzp] = pd.Series(data=exp_df[rzp], index=euler_df.index)

            # List the columns that contain exponential map channels
            exp_params = [c for c in exp_df.columns if (any(p in c for p in ['alpha', 'beta', 'gamma']) and 'Nub' not in c)]

            # List the joints that are not end sites, i.e., have channels
            joints = (joint for joint in track.skeleton if 'Nub' not in joint)

            for joint in joints:
                r = exp_df[[c for c in exp_params if joint in c]]  # Get the columns that belong to this joint
                expmap = [[f[1]['%s_alpha'%joint], f[1]['%s_beta'%joint], f[1]['%s_gamma'%joint]] for f in r.iterrows()]
                # Convert the exp maps back to eulers, in degrees (rotation order is required by Rotation.__init__)
                euler_rots = [Rotation(f, 'expmap', track.skeleton[joint]['order']).to_euler(True)[0] for f in expmap]

                # Create the corresponding columns in the new DataFrame
                euler_df['%s_Xrotation'%joint] = pd.Series(data=[e[0] for e in euler_rots], index=euler_df.index)
                euler_df['%s_Yrotation'%joint] = pd.Series(data=[e[1] for e in euler_rots], index=euler_df.index)
                euler_df['%s_Zrotation'%joint] = pd.Series(data=[e[2] for e in euler_rots], index=euler_df.index)

            new_track = track.clone()
            new_track.values = euler_df
            Q.append(new_track)

        return Q

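# --- Illustrative usage sketch (not part of preprocessing.py) ---
# MocapParameterizer follows the scikit-learn transformer API; 'mocap' is a
# parsed track as in the BVHParser example above. Note that 'position' has no
# inverse here, so keep the euler tracks around if you still need to export.
to_expmap = MocapParameterizer('expmap')
exp_tracks = to_expmap.fit_transform([mocap])           # euler angles -> exponential maps
euler_tracks = to_expmap.inverse_transform(exp_tracks)  # exponential maps -> euler angles

to_pos = MocapParameterizer('position')
pos_tracks = to_pos.fit_transform([mocap])              # forward kinematics -> world positions
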
class JointSelector(BaseEstimator, TransformerMixin):
    '''
    Allows for filtering the mocap data to include only the selected joints
    '''
    def __init__(self, joints, include_root=False):
        self.joints = joints
        self.include_root = include_root

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        selected_joints = []
        selected_channels = []

        if self.include_root:
            selected_joints.append(X[0].root_name)

        selected_joints.extend(self.joints)

        for joint_name in selected_joints:
            selected_channels.extend([o for o in X[0].values.columns if joint_name in o])

        Q = []

        for track in X:
            t2 = track.clone()

            for key in track.skeleton.keys():
                if key not in selected_joints:
                    t2.skeleton.pop(key)
            t2.values = track.values[selected_channels]

            Q.append(t2)

        return Q


class Numpyfier(BaseEstimator, TransformerMixin):
    '''
    Just converts the values in a MocapData object into a numpy array
    Useful for the final stage of a pipeline before training
    '''
    def __init__(self):
        pass

    def fit(self, X, y=None):
        self.org_mocap_ = X[0].clone()
        self.org_mocap_.values.drop(self.org_mocap_.values.index, inplace=True)

        return self

    def transform(self, X, y=None):
        Q = []

        for track in X:
            Q.append(track.values.values)

        return np.array(Q)

    def inverse_transform(self, X, copy=None):
        Q = []

        for track in X:
            new_mocap = self.org_mocap_.clone()
            time_index = pd.to_timedelta([f for f in range(track.shape[0])], unit='s')

            new_df = pd.DataFrame(data=track, index=time_index, columns=self.org_mocap_.values.columns)

            new_mocap.values = new_df

            Q.append(new_mocap)

        return Q

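# --- Illustrative usage sketch (not part of preprocessing.py) ---
# JointSelector trims tracks to a subset of joints and Numpyfier turns them
# into a training-ready array; the joint names below are placeholders, and
# 'pos_tracks' continues the MocapParameterizer sketch above.
selector = JointSelector(['Spine', 'Head'], include_root=True)
numpyfier = Numpyfier()
trimmed = selector.fit_transform(pos_tracks)
batch = numpyfier.fit_transform(trimmed)          # (n_tracks, n_frames, n_channels) for equal-length tracks
tracks_back = numpyfier.inverse_transform(batch)  # rebuild MocapData objects from the array
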
class RootTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, method):
        """
        Accepted methods:
        abdolute_translation_deltas  (this exact spelling is what transform() checks)
        pos_rot_deltas
        """
        self.method = method

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        Q = []

        for track in X:
            if self.method == 'abdolute_translation_deltas':
                new_df = track.values.copy()
                xpcol = '%s_Xposition'%track.root_name
                ypcol = '%s_Yposition'%track.root_name
                zpcol = '%s_Zposition'%track.root_name

                dxpcol = '%s_dXposition'%track.root_name
                dzpcol = '%s_dZposition'%track.root_name

                dx = track.values[xpcol].diff()
                dz = track.values[zpcol].diff()

                dx[0] = 0
                dz[0] = 0

                new_df.drop([xpcol, zpcol], axis=1, inplace=True)

                new_df[dxpcol] = dx
                new_df[dzpcol] = dz

                new_track = track.clone()
                new_track.values = new_df
                # end of abdolute_translation_deltas

            elif self.method == 'pos_rot_deltas':
                new_track = track.clone()

                # Absolute columns
                xp_col = '%s_Xposition'%track.root_name
                yp_col = '%s_Yposition'%track.root_name
                zp_col = '%s_Zposition'%track.root_name

                xr_col = '%s_Xrotation'%track.root_name
                yr_col = '%s_Yrotation'%track.root_name
                zr_col = '%s_Zrotation'%track.root_name

                # Delta columns
                dxp_col = '%s_dXposition'%track.root_name
                dzp_col = '%s_dZposition'%track.root_name

                dxr_col = '%s_dXrotation'%track.root_name
                dyr_col = '%s_dYrotation'%track.root_name
                dzr_col = '%s_dZrotation'%track.root_name

                new_df = track.values.copy()

                root_pos_x_diff = pd.Series(data=track.values[xp_col].diff(), index=new_df.index)
                root_pos_z_diff = pd.Series(data=track.values[zp_col].diff(), index=new_df.index)

                root_rot_y_diff = pd.Series(data=track.values[yr_col].diff(), index=new_df.index)
                root_rot_x_diff = pd.Series(data=track.values[xr_col].diff(), index=new_df.index)
                root_rot_z_diff = pd.Series(data=track.values[zr_col].diff(), index=new_df.index)

                root_pos_x_diff[0] = 0
                root_pos_z_diff[0] = 0

                root_rot_y_diff[0] = 0
                root_rot_x_diff[0] = 0
                root_rot_z_diff[0] = 0

                new_df.drop([xr_col, yr_col, zr_col, xp_col, zp_col], axis=1, inplace=True)

                new_df[dxp_col] = root_pos_x_diff
                new_df[dzp_col] = root_pos_z_diff

                new_df[dxr_col] = root_rot_x_diff
                new_df[dyr_col] = root_rot_y_diff
                new_df[dzr_col] = root_rot_z_diff

                new_track.values = new_df

            Q.append(new_track)

        return Q

    def inverse_transform(self, X, copy=None, start_pos=None):
        Q = []

        #TODO: simplify this implementation

        startx = 0
        startz = 0

        if start_pos is not None:
            startx, startz = start_pos

        for track in X:
            new_track = track.clone()
            if self.method == 'abdolute_translation_deltas':
                new_df = new_track.values
                xpcol = '%s_Xposition'%track.root_name
                ypcol = '%s_Yposition'%track.root_name
                zpcol = '%s_Zposition'%track.root_name

                dxpcol = '%s_dXposition'%track.root_name
                dzpcol = '%s_dZposition'%track.root_name

                dx = track.values[dxpcol].values
                dz = track.values[dzpcol].values

                recx = [startx]
                recz = [startz]

                for i in range(dx.shape[0]-1):
                    recx.append(recx[i]+dx[i+1])
                    recz.append(recz[i]+dz[i+1])

                new_df[xpcol] = pd.Series(data=recx, index=new_df.index)
                new_df[zpcol] = pd.Series(data=recz, index=new_df.index)

                new_df.drop([dxpcol, dzpcol], axis=1, inplace=True)

                new_track.values = new_df
                # end of abdolute_translation_deltas

            elif self.method == 'pos_rot_deltas':
                new_track = track.clone()

                # Absolute columns
                xp_col = '%s_Xposition'%track.root_name
                yp_col = '%s_Yposition'%track.root_name
                zp_col = '%s_Zposition'%track.root_name

                xr_col = '%s_Xrotation'%track.root_name
                yr_col = '%s_Yrotation'%track.root_name
                zr_col = '%s_Zrotation'%track.root_name

                # Delta columns
                dxp_col = '%s_dXposition'%track.root_name
                dzp_col = '%s_dZposition'%track.root_name

                dxr_col = '%s_dXrotation'%track.root_name
                dyr_col = '%s_dYrotation'%track.root_name
                dzr_col = '%s_dZrotation'%track.root_name

                new_df = track.values.copy()

                dx = track.values[dxp_col].values
                dz = track.values[dzp_col].values

                drx = track.values[dxr_col].values
                dry = track.values[dyr_col].values
                drz = track.values[dzr_col].values

                rec_xp = [startx]
                rec_zp = [startz]

                rec_xr = [0]
                rec_yr = [0]
                rec_zr = [0]

                for i in range(dx.shape[0]-1):
                    rec_xp.append(rec_xp[i]+dx[i+1])
                    rec_zp.append(rec_zp[i]+dz[i+1])

                    rec_xr.append(rec_xr[i]+drx[i+1])
                    rec_yr.append(rec_yr[i]+dry[i+1])
                    rec_zr.append(rec_zr[i]+drz[i+1])

                new_df[xp_col] = pd.Series(data=rec_xp, index=new_df.index)
                new_df[zp_col] = pd.Series(data=rec_zp, index=new_df.index)

                new_df[xr_col] = pd.Series(data=rec_xr, index=new_df.index)
                new_df[yr_col] = pd.Series(data=rec_yr, index=new_df.index)
                new_df[zr_col] = pd.Series(data=rec_zr, index=new_df.index)

                new_df.drop([dxr_col, dyr_col, dzr_col, dxp_col, dzp_col], axis=1, inplace=True)

                new_track.values = new_df

            Q.append(new_track)

        return Q

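# --- Illustrative round-trip sketch (not part of preprocessing.py) ---
# The method string is spelled exactly as transform() checks it; inverse_transform
# reintegrates the per-frame deltas starting from an optional start position.
root_deltas = RootTransformer('abdolute_translation_deltas')
delta_tracks = root_deltas.fit_transform([mocap])
restored = root_deltas.inverse_transform(delta_tracks, start_pos=(0, 0))
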
class RootCentricPositionNormalizer(BaseEstimator, TransformerMixin):
    def __init__(self):
        pass

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        Q = []

        for track in X:
            new_track = track.clone()

            rxp = '%s_Xposition'%track.root_name
            ryp = '%s_Yposition'%track.root_name
            rzp = '%s_Zposition'%track.root_name

            projected_root_pos = track.values[[rxp, ryp, rzp]]

            projected_root_pos.loc[:, ryp] = 0  # we want the root's projection on the floor plane as the ref

            new_df = pd.DataFrame(index=track.values.index)

            all_but_root = [joint for joint in track.skeleton if track.root_name not in joint]
            for joint in all_but_root:
                new_df['%s_Xposition'%joint] = pd.Series(data=track.values['%s_Xposition'%joint]-projected_root_pos[rxp], index=new_df.index)
                new_df['%s_Yposition'%joint] = pd.Series(data=track.values['%s_Yposition'%joint]-projected_root_pos[ryp], index=new_df.index)
                new_df['%s_Zposition'%joint] = pd.Series(data=track.values['%s_Zposition'%joint]-projected_root_pos[rzp], index=new_df.index)

            # keep the root as it is now
            new_df[rxp] = track.values[rxp]
            new_df[ryp] = track.values[ryp]
            new_df[rzp] = track.values[rzp]

            new_track.values = new_df

            Q.append(new_track)

        return Q

    def inverse_transform(self, X, copy=None):
        Q = []

        for track in X:
            new_track = track.clone()

            rxp = '%s_Xposition'%track.root_name
            ryp = '%s_Yposition'%track.root_name
            rzp = '%s_Zposition'%track.root_name

            projected_root_pos = track.values[[rxp, ryp, rzp]]

            projected_root_pos.loc[:, ryp] = 0  # we want the root's projection on the floor plane as the ref

            new_df = pd.DataFrame(index=track.values.index)

            for joint in track.skeleton:
                new_df['%s_Xposition'%joint] = pd.Series(data=track.values['%s_Xposition'%joint]+projected_root_pos[rxp], index=new_df.index)
                new_df['%s_Yposition'%joint] = pd.Series(data=track.values['%s_Yposition'%joint]+projected_root_pos[ryp], index=new_df.index)
                new_df['%s_Zposition'%joint] = pd.Series(data=track.values['%s_Zposition'%joint]+projected_root_pos[rzp], index=new_df.index)

            new_track.values = new_df

            Q.append(new_track)

        return Q


class Flattener(BaseEstimator, TransformerMixin):
    def __init__(self):
        pass

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return np.concatenate(X, axis=0)

class ConstantsRemover(BaseEstimator, TransformerMixin):
    '''
    For now it just looks at the first track
    '''

    def __init__(self, eps=10e-10):
        self.eps = eps

    def fit(self, X, y=None):
        stds = X[0].values.std()
        cols = X[0].values.columns.values
        self.const_dims_ = [c for c in cols if (stds[c] < self.eps).any()]
        self.const_values_ = {c: X[0].values[c].values[0] for c in cols if (stds[c] < self.eps).any()}
        return self

    def transform(self, X, y=None):
        Q = []

        for track in X:
            t2 = track.clone()
            t2.values = track.values[track.values.columns.difference(self.const_dims_)]
            Q.append(t2)

        return Q

    def inverse_transform(self, X, copy=None):
        Q = []

        for track in X:
            t2 = track.clone()
            for d in self.const_dims_:
                t2.values[d] = self.const_values_[d]
            Q.append(t2)

        return Q

class ListStandardScaler(BaseEstimator, TransformerMixin):
    def __init__(self, is_DataFrame=False):
        self.is_DataFrame = is_DataFrame

    def fit(self, X, y=None):
        if self.is_DataFrame:
            X_train_flat = np.concatenate([m.values for m in X], axis=0)
        else:
            X_train_flat = np.concatenate([m for m in X], axis=0)

        self.data_mean_ = np.mean(X_train_flat, axis=0)
        self.data_std_ = np.std(X_train_flat, axis=0)

        return self

    def transform(self, X, y=None):
        Q = []

        for track in X:
            if self.is_DataFrame:
                normalized_track = track.copy()
                normalized_track.values = (track.values - self.data_mean_) / self.data_std_
            else:
                normalized_track = (track - self.data_mean_) / self.data_std_

            Q.append(normalized_track)

        if self.is_DataFrame:
            return Q
        else:
            return np.array(Q)

    def inverse_transform(self, X, copy=None):
        Q = []

        for track in X:
            if self.is_DataFrame:
                unnormalized_track = track.copy()
                unnormalized_track.values = (track.values * self.data_std_) + self.data_mean_
            else:
                unnormalized_track = (track * self.data_std_) + self.data_mean_

            Q.append(unnormalized_track)

        if self.is_DataFrame:
            return Q
        else:
            return np.array(Q)

class DownSampler(BaseEstimator, TransformerMixin):
    def __init__(self, rate):
        self.rate = rate

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        Q = []

        for track in X:
            new_track = track[0:-1:self.rate]
            Q.append(new_track)

        return Q

    def inverse_transform(self, X, copy=None):
        return X


#TODO: JointsSelector (x)
#TODO: SegmentMaker
#TODO: DynamicFeaturesAdder
#TODO: ShapeFeaturesAdder
#TODO: DataFrameNumpier (x)

class TemplateTransform(BaseEstimator, TransformerMixin):
    def __init__(self):
        pass

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X

class UnsupportedParamError(Exception):
    def __init__(self, message):
        self.message = message
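Because every class in this file implements fit/transform (and usually inverse_transform), the transformers compose directly into a scikit-learn Pipeline. A hedged sketch of one typical ordering; the step names and the input track are illustrative:

from sklearn.pipeline import Pipeline

data_pipe = Pipeline([
    ('param', MocapParameterizer('position')),  # euler channels -> world positions
    ('root', RootCentricPositionNormalizer()),  # positions relative to the root's floor projection
    ('const', ConstantsRemover()),              # drop near-constant channels
    ('np', Numpyfier()),                        # MocapData -> ndarray for training
])
train_array = data_pipe.fit_transform([mocap])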
dataloaders/pymo/rotation_tools.py
ADDED
@@ -0,0 +1,153 @@
'''
Tools for Manipulating and Converting 3D Rotations

By Omid Alemi
Created: June 12, 2017

Adapted from that matlab file...
'''

import math
import numpy as np

def deg2rad(x):
    return x/180*math.pi


def rad2deg(x):
    return x/math.pi*180

class Rotation():
    def __init__(self, rot, param_type, rotation_order, **params):
        self.rotmat = []
        self.rotation_order = rotation_order
        if param_type == 'euler':
            self._from_euler(rot[0], rot[1], rot[2], params)
        elif param_type == 'expmap':
            self._from_expmap(rot[0], rot[1], rot[2], params)

    def _from_euler(self, alpha, beta, gamma, params):
        '''Expecting degrees'''

        if params['from_deg'] == True:
            alpha = deg2rad(alpha)
            beta = deg2rad(beta)
            gamma = deg2rad(gamma)

        ca = math.cos(alpha)
        cb = math.cos(beta)
        cg = math.cos(gamma)
        sa = math.sin(alpha)
        sb = math.sin(beta)
        sg = math.sin(gamma)

        Rx = np.asarray([[1, 0, 0],
                         [0, ca, sa],
                         [0, -sa, ca]])

        Ry = np.asarray([[cb, 0, -sb],
                         [0, 1, 0],
                         [sb, 0, cb]])

        Rz = np.asarray([[cg, sg, 0],
                         [-sg, cg, 0],
                         [0, 0, 1]])

        self.rotmat = np.eye(3)

        # inner product of the rotation matrices, in the order defined by the BVH file
        for axis in self.rotation_order:
            if axis == 'X':
                self.rotmat = np.matmul(Rx, self.rotmat)
            elif axis == 'Y':
                self.rotmat = np.matmul(Ry, self.rotmat)
            else:
                self.rotmat = np.matmul(Rz, self.rotmat)

    def _from_expmap(self, alpha, beta, gamma, params):
        if alpha == 0 and beta == 0 and gamma == 0:
            self.rotmat = np.eye(3)
            return

        #TODO: Check exp map params

        theta = np.linalg.norm([alpha, beta, gamma])

        expmap = [alpha, beta, gamma] / theta

        x = expmap[0]
        y = expmap[1]
        z = expmap[2]

        s = math.sin(theta/2)
        c = math.cos(theta/2)

        self.rotmat = np.asarray([
            [2*(x**2-1)*s**2+1, 2*x*y*s**2-2*z*c*s, 2*x*z*s**2+2*y*c*s],
            [2*x*y*s**2+2*z*c*s, 2*(y**2-1)*s**2+1, 2*y*z*s**2-2*x*c*s],
            [2*x*z*s**2-2*y*c*s, 2*y*z*s**2+2*x*c*s, 2*(z**2-1)*s**2+1]
        ])

    def get_euler_axis(self):
        R = self.rotmat
        theta = math.acos((self.rotmat.trace() - 1) / 2)
        axis = np.asarray([R[2,1] - R[1,2], R[0,2] - R[2,0], R[1,0] - R[0,1]])
        axis = axis/(2*math.sin(theta))
        return theta, axis

    def to_expmap(self):
        theta, axis = self.get_euler_axis()
        rot_arr = theta * axis
        if np.isnan(rot_arr).any():
            rot_arr = [0, 0, 0]
        return rot_arr

    def to_euler(self, use_deg=False):
        eulers = np.zeros((2, 3))

        if np.absolute(np.absolute(self.rotmat[2, 0]) - 1) < 1e-12:
            # GIMBAL LOCK!
            print('Gimbal')
            if np.absolute(self.rotmat[2, 0] - 1) < 1e-12:  # R[2,0] close to +1
                eulers[:, 0] = math.atan2(-self.rotmat[0, 1], -self.rotmat[0, 2])
                eulers[:, 1] = -math.pi/2
            else:  # R[2,0] close to -1
                eulers[:, 0] = math.atan2(self.rotmat[0, 1], -self.rotmat[0, 2])
                eulers[:, 1] = math.pi/2

            return eulers

        theta = -math.asin(self.rotmat[2, 0])
        theta2 = math.pi - theta

        # psi1, psi2
        eulers[0, 0] = math.atan2(self.rotmat[2, 1]/math.cos(theta), self.rotmat[2, 2]/math.cos(theta))
        eulers[1, 0] = math.atan2(self.rotmat[2, 1]/math.cos(theta2), self.rotmat[2, 2]/math.cos(theta2))

        # theta1, theta2
        eulers[0, 1] = theta
        eulers[1, 1] = theta2

        # phi1, phi2
        eulers[0, 2] = math.atan2(self.rotmat[1, 0]/math.cos(theta), self.rotmat[0, 0]/math.cos(theta))
        eulers[1, 2] = math.atan2(self.rotmat[1, 0]/math.cos(theta2), self.rotmat[0, 0]/math.cos(theta2))

        if use_deg:
            eulers = rad2deg(eulers)

        return eulers

    def to_quat(self):
        #TODO
        pass

    def __str__(self):
        return "Rotation Matrix: \n " + self.rotmat.__str__()
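A small sketch of the euler/expmap conversions implemented above (illustrative; angles in degrees, 'ZXY' is just an example BVH channel order, and to_euler returns the two candidate XYZ solutions rather than reproducing the input order):

r = Rotation([45.0, 10.0, -30.0], 'euler', 'ZXY', from_deg=True)
em = r.to_expmap()                 # axis-angle vector, in radians
r2 = Rotation(em, 'expmap', 'ZXY')
print(r2.to_euler(True))           # 2x3 array of euler solutions, in degrees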
dataloaders/pymo/rotation_tools.py!
ADDED
@@ -0,0 +1,69 @@
'''
Tools for Manipulating and Converting 3D Rotations

By Omid Alemi
Created: June 12, 2017

Adapted from that matlab file...
'''

import math
import numpy as np

def deg2rad(x):
    return x/180*math.pi

class Rotation():
    def __init__(self, rot, param_type, **params):
        self.rotmat = []
        if param_type == 'euler':
            self._from_euler(rot[0], rot[1], rot[2], params)

    def _from_euler(self, alpha, beta, gamma, params):
        '''Expecting degrees'''

        if params['from_deg'] == True:
            alpha = deg2rad(alpha)
            beta = deg2rad(beta)
            gamma = deg2rad(gamma)

        Rx = np.asarray([[1, 0, 0],
                         [0, math.cos(alpha), -math.sin(alpha)],
                         [0, math.sin(alpha), math.cos(alpha)]])

        Ry = np.asarray([[math.cos(beta), 0, math.sin(beta)],
                         [0, 1, 0],
                         [-math.sin(beta), 0, math.cos(beta)]])

        Rz = np.asarray([[math.cos(gamma), -math.sin(gamma), 0],
                         [math.sin(gamma), math.cos(gamma), 0],
                         [0, 0, 1]])

        self.rotmat = np.matmul(np.matmul(Rz, Ry), Rx).T

    def get_euler_axis(self):
        R = self.rotmat
        theta = math.acos((self.rotmat.trace() - 1) / 2)
        axis = np.asarray([R[2,1] - R[1,2], R[0,2] - R[2,0], R[1,0] - R[0,1]])
        axis = axis/(2*math.sin(theta))
        return theta, axis

    def to_expmap(self):
        theta, axis = self.get_euler_axis()
        rot_arr = theta * axis
        if np.isnan(rot_arr).any():
            rot_arr = [0, 0, 0]
        return rot_arr

    def to_euler(self):
        #TODO
        pass

    def to_quat(self):
        #TODO
        pass
dataloaders/pymo/viz_tools.py
ADDED
@@ -0,0 +1,236 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import IPython
import os

def save_fig(fig_id, tight_layout=True):
    if tight_layout:
        plt.tight_layout()
    plt.savefig(fig_id + '.png', format='png', dpi=300)


def draw_stickfigure(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)):
    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

    if joints is None:
        joints_to_draw = mocap_track.skeleton.keys()
    else:
        joints_to_draw = joints

    if data is None:
        df = mocap_track.values
    else:
        df = data

    for joint in joints_to_draw:
        ax.scatter(x=df['%s_Xposition'%joint][frame],
                   y=df['%s_Yposition'%joint][frame],
                   alpha=0.6, c='b', marker='o')

        parent_x = df['%s_Xposition'%joint][frame]
        parent_y = df['%s_Yposition'%joint][frame]

        children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw]

        for c in children_to_draw:
            child_x = df['%s_Xposition'%c][frame]
            child_y = df['%s_Yposition'%c][frame]
            ax.plot([parent_x, child_x], [parent_y, child_y], 'k-', lw=2)

        if draw_names:
            ax.annotate(joint,
                        (df['%s_Xposition'%joint][frame] + 0.1,
                         df['%s_Yposition'%joint][frame] + 0.1))

    return ax

def draw_stickfigure3d(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)):
    from mpl_toolkits.mplot3d import Axes3D

    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')

    if joints is None:
        joints_to_draw = mocap_track.skeleton.keys()
    else:
        joints_to_draw = joints

    if data is None:
        df = mocap_track.values
    else:
        df = data

    for joint in joints_to_draw:
        parent_x = df['%s_Xposition'%joint][frame]
        parent_y = df['%s_Zposition'%joint][frame]
        parent_z = df['%s_Yposition'%joint][frame]
        # ^ In mocaps, Y is the up-right axis

        ax.scatter(xs=parent_x,
                   ys=parent_y,
                   zs=parent_z,
                   alpha=0.6, c='b', marker='o')

        children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw]

        for c in children_to_draw:
            child_x = df['%s_Xposition'%c][frame]
            child_y = df['%s_Zposition'%c][frame]
            child_z = df['%s_Yposition'%c][frame]
            # ^ In mocaps, Y is the up-right axis

            ax.plot([parent_x, child_x], [parent_y, child_y], [parent_z, child_z], 'k-', lw=2, c='black')

        if draw_names:
            ax.text(x=parent_x + 0.1,
                    y=parent_y + 0.1,
                    z=parent_z + 0.1,
                    s=joint,
                    color=(0, 0, 0, 0.9))  # RGBA tuple; matplotlib does not accept CSS 'rgba(...)' strings

    return ax


def sketch_move(mocap_track, data=None, ax=None, figsize=(16,8)):
    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

    if data is None:
        data = mocap_track.values

    for frame in range(0, data.shape[0], 4):
        for joint in mocap_track.skeleton.keys():
            children_to_draw = [c for c in mocap_track.skeleton[joint]['children']]

            parent_x = data['%s_Xposition'%joint][frame]
            parent_y = data['%s_Yposition'%joint][frame]

            frame_alpha = frame/data.shape[0]

            for c in children_to_draw:
                child_x = data['%s_Xposition'%c][frame]
                child_y = data['%s_Yposition'%c][frame]

                ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha)


def viz_cnn_filter(feature_to_viz, mocap_track, data, gap=25):
    fig = plt.figure(figsize=(16,4))
    ax = plt.subplot2grid((1,8), (0,0))
    ax.imshow(feature_to_viz.T, aspect='auto', interpolation='nearest')

    ax = plt.subplot2grid((1,8), (0,1), colspan=7)
    for frame in range(feature_to_viz.shape[0]):
        frame_alpha = 0.2

        for joint_i, joint in enumerate(mocap_track.skeleton.keys()):
            children_to_draw = [c for c in mocap_track.skeleton[joint]['children']]

            parent_x = data['%s_Xposition'%joint][frame] + frame * gap
            parent_y = data['%s_Yposition'%joint][frame]

            ax.scatter(x=parent_x,
                       y=parent_y,
                       alpha=0.6,
                       cmap='RdBu',
                       c=feature_to_viz[frame][joint_i] * 10000,
                       marker='o',
                       s=abs(feature_to_viz[frame][joint_i] * 10000))
            plt.axis('off')
            for c in children_to_draw:
                child_x = data['%s_Xposition'%c][frame] + frame * gap
                child_y = data['%s_Yposition'%c][frame]

                ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha)


def print_skel(X):
    stack = [X.root_name]
    tab = 0
    while stack:
        joint = stack.pop()
        tab = len(stack)
        print('%s- %s (%s)'%('| '*tab, joint, X.skeleton[joint]['parent']))
        for c in X.skeleton[joint]['children']:
            stack.append(c)


def nb_play_mocap_fromurl(mocap, mf, frame_time=1/30, scale=1, base_url='http://titan:8385'):
    if mf == 'bvh':
        from .writers import BVHWriter  # BVHWriter is defined in writers.py
        bw = BVHWriter()
        with open('test.bvh', 'w') as ofile:
            bw.write(mocap, ofile)

        filepath = '../notebooks/test.bvh'
    elif mf == 'pos':
        c = list(mocap.values.columns)

        for cc in list(c):  # iterate over a copy so remove() does not skip entries
            if 'rotation' in cc:
                c.remove(cc)
        mocap.values.to_csv('test.csv', index=False, columns=c)

        filepath = '../notebooks/test.csv'
    else:
        return

    url = '%s/mocapplayer/player.html?data_url=%s&scale=%f&cz=200&order=xzyi&frame_time=%f'%(base_url, filepath, scale, frame_time)
    iframe = '<iframe src=' + url + ' width="100%" height=500></iframe>'
    link = '<a href=%s target="_blank">New Window</a>'%url
    return IPython.display.HTML(iframe+link)

def nb_play_mocap(mocap, mf, meta=None, frame_time=1/30, scale=1, camera_z=500, base_url=None):
    data_template = 'var dataBuffer = `$$DATA$$`;'
    data_template += 'var metadata = $$META$$;'
    data_template += 'start(dataBuffer, metadata, $$CZ$$, $$SCALE$$, $$FRAMETIME$$);'
    dir_path = os.path.dirname(os.path.realpath(__file__))

    if base_url is None:
        base_url = os.path.join(dir_path, 'mocapplayer/playBuffer.html')

    if mf == 'bvh':
        pass
    elif mf == 'pos':
        cols = list(mocap.values.columns)
        for c in list(cols):  # iterate over a copy so remove() does not skip entries
            if 'rotation' in c:
                cols.remove(c)

        data_csv = mocap.values.to_csv(index=False, columns=cols)

        if meta is not None:
            lines = [','.join(item) for item in meta.astype('str')]
            meta_csv = '[' + ','.join('[%s]'%l for l in lines) + ']'
        else:
            meta_csv = '[]'

        data_assigned = data_template.replace('$$DATA$$', data_csv)
        data_assigned = data_assigned.replace('$$META$$', meta_csv)
        data_assigned = data_assigned.replace('$$CZ$$', str(camera_z))
        data_assigned = data_assigned.replace('$$SCALE$$', str(scale))
        data_assigned = data_assigned.replace('$$FRAMETIME$$', str(frame_time))

    else:
        return

    with open(os.path.join(dir_path, 'mocapplayer/data.js'), 'w') as oFile:
        oFile.write(data_assigned)

    url = '%s?&cz=200&order=xzyi&frame_time=%f&scale=%f'%(base_url, frame_time, scale)
    iframe = '<iframe frameborder="0" src=' + url + ' width="100%" height=500></iframe>'
    link = '<a href=%s target="_blank">New Window</a>'%url
    return IPython.display.HTML(iframe+link)
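In a notebook, the drawing helpers take a position-parameterized track, e.g. 'pos_tracks' from the MocapParameterizer sketch in preprocessing.py above (illustrative; the frame index is arbitrary):

ax = draw_stickfigure(pos_tracks[0], frame=0)
plt.show()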
dataloaders/pymo/writers.py
ADDED
@@ -0,0 +1,55 @@
import numpy as np
import pandas as pd

class BVHWriter():
    def __init__(self):
        pass

    def write(self, X, ofile):
        # Writing the skeleton info
        ofile.write('HIERARCHY\n')

        self.motions_ = []
        self._printJoint(X, X.root_name, 0, ofile)

        # Writing the motion header
        ofile.write('MOTION\n')
        ofile.write('Frames: %d\n'%X.values.shape[0])
        ofile.write('Frame Time: %f\n'%X.framerate)

        # Writing the data
        self.motions_ = np.asarray(self.motions_).T
        lines = [" ".join(item) for item in self.motions_.astype(str)]
        ofile.write("".join("%s\n"%l for l in lines))

    def _printJoint(self, X, joint, tab, ofile):
        if X.skeleton[joint]['parent'] is None:
            ofile.write('ROOT %s\n'%joint)
        elif len(X.skeleton[joint]['children']) > 0:
            ofile.write('%sJOINT %s\n'%('\t'*(tab), joint))
        else:
            ofile.write('%sEnd site\n'%('\t'*(tab)))

        ofile.write('%s{\n'%('\t'*(tab)))

        ofile.write('%sOFFSET %3.5f %3.5f %3.5f\n'%('\t'*(tab+1),
                                                    X.skeleton[joint]['offsets'][0],
                                                    X.skeleton[joint]['offsets'][1],
                                                    X.skeleton[joint]['offsets'][2]))
        channels = X.skeleton[joint]['channels']
        n_channels = len(channels)

        if n_channels > 0:
            for ch in channels:
                self.motions_.append(np.asarray(X.values['%s_%s'%(joint, ch)].values))

        if len(X.skeleton[joint]['children']) > 0:
            ch_str = ''.join(' %s'*n_channels%tuple(channels))
            ofile.write('%sCHANNELS %d%s\n' %('\t'*(tab+1), n_channels, ch_str))

        for c in X.skeleton[joint]['children']:
            self._printJoint(X, c, tab+1, ofile)

        ofile.write('%s}\n'%('\t'*(tab)))
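A write-out sketch (illustrative): the track passed to write() must still carry the euler channels named in each joint's skeleton['channels'], e.g. a freshly parsed or inverse-transformed track; the output path is a placeholder.

writer = BVHWriter()
with open('exported.bvh', 'w') as ofile:
    writer.write(mocap, ofile)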
dataloaders/utils/__pycache__/audio_features.cpython-312.pyc
ADDED
Binary file (4.64 kB).
dataloaders/utils/__pycache__/other_tools.cpython-312.pyc
ADDED
Binary file (37.2 kB).
dataloaders/utils/__pycache__/rotation_conversions.cpython-312.pyc
ADDED
Binary file (22.2 kB).
dataloaders/utils/audio_features.py
ADDED
@@ -0,0 +1,80 @@
"""modified from https://github.com/yesheng-THU/GFGE/blob/main/data_processing/audio_features.py"""
import numpy as np
import librosa
import math
import os
import scipy.io.wavfile as wav
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from tqdm import tqdm
from typing import Optional, Tuple
from numpy.lib import stride_tricks
from loguru import logger

# Import Wav2Vec2Model to make it available for other modules
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from models.utils.wav2vec import Wav2Vec2Model


def process_audio_data(audio_file, args, data, f_name, selected_file):
    """Process audio data with support for different representations."""
    logger.info(f"# ---- Building cache for Audio {f_name} ---- #")

    if not os.path.exists(audio_file):
        logger.warning(f"# ---- file not found for Audio {f_name}, skip all files with the same id ---- #")
        selected_file.drop(selected_file[selected_file['id'] == f_name].index, inplace=True)
        return None

    audio_save_path = audio_file.replace("wave16k", "onset_amplitude").replace(".wav", ".npy")

    if args.audio_rep == "onset+amplitude" and os.path.exists(audio_save_path):
        data['audio'] = np.load(audio_save_path)
        logger.warning(f"# ---- file found cache for Audio {f_name} ---- #")
    elif args.audio_rep == "onset+amplitude":
        data['audio'] = calculate_onset_amplitude(audio_file, args.audio_sr, audio_save_path)
    elif args.audio_rep == "mfcc":
        # Note: despite the "mfcc" name, this branch computes a mel spectrogram.
        audio_data, _ = librosa.load(audio_file)
        data['audio'] = librosa.feature.melspectrogram(
            y=audio_data,
            sr=args.audio_sr,
            n_mels=128,
            hop_length=int(args.audio_sr / args.audio_fps)
        ).transpose(1, 0)

    if args.audio_norm and args.audio_rep == "wave16k":
        data['audio'] = (data['audio'] - args.mean_audio) / args.std_audio

    return data

def calculate_onset_amplitude(audio_file, audio_sr, save_path):
    """Calculate onset and amplitude features from audio file."""
    audio_data, sr = librosa.load(audio_file)
    audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=audio_sr)

    # Calculate amplitude envelope with a 1024-sample sliding window
    frame_length = 1024
    shape = (audio_data.shape[-1] - frame_length + 1, frame_length)
    strides = (audio_data.strides[-1], audio_data.strides[-1])
    rolling_view = stride_tricks.as_strided(audio_data, shape=shape, strides=strides)
    amplitude_envelope = np.max(np.abs(rolling_view), axis=1)
    amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length - 1), mode='constant', constant_values=amplitude_envelope[-1])

    # Calculate onset
    audio_onset_f = librosa.onset.onset_detect(y=audio_data, sr=audio_sr, units='frames')
    onset_array = np.zeros(len(audio_data), dtype=float)
    onset_array[audio_onset_f] = 1.0

    # Combine features
    features = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1)

    # Save features
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    np.save(save_path, features)

    return features
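A quick usage sketch for the extractor above; both paths are placeholders, not files in the repo.

# Hypothetical invocation of calculate_onset_amplitude (paths are placeholders).
feats = calculate_onset_amplitude('wave16k/example.wav', audio_sr=16000,
                                  save_path='onset_amplitude/example.npy')
print(feats.shape)  # (num_samples, 2): amplitude envelope, onset indicator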
dataloaders/utils/data_sample.py
ADDED
@@ -0,0 +1,175 @@
import math
import numpy as np
from collections import defaultdict
from loguru import logger

def sample_from_clip(
        lmdb_manager, audio_file, audio_each_file, pose_each_file, trans_each_file,
        trans_v_each_file, shape_each_file, facial_each_file, word_each_file,
        vid_each_file, emo_each_file, sem_each_file, args, ori_stride, ori_length,
        disable_filtering, clean_first_seconds, clean_final_seconds, is_test,
        n_out_samples):
    """Sample clips from the data according to specified parameters."""
    round_seconds_skeleton = pose_each_file.shape[0] // args.pose_fps

    # Calculate timing information
    timing_info = calculate_timing_info(
        audio_each_file, facial_each_file, round_seconds_skeleton,
        args.audio_fps, args.pose_fps, args.audio_sr, args.audio_rep
    )
    round_seconds_skeleton = timing_info['final_seconds']

    # Calculate clip boundaries
    clip_info = calculate_clip_boundaries(
        round_seconds_skeleton, clean_first_seconds, clean_final_seconds,
        args.audio_fps, args.pose_fps
    )

    n_filtered_out = defaultdict(int)

    # Process each training length ratio
    for ratio in args.multi_length_training:
        processed_data = process_data_with_ratio(
            ori_stride, ori_length, ratio, clip_info, args, is_test,
            audio_each_file, pose_each_file, trans_each_file, trans_v_each_file,
            shape_each_file, facial_each_file, word_each_file, vid_each_file,
            emo_each_file, sem_each_file, audio_file,
            lmdb_manager, n_out_samples
        )

        for type_key, count in processed_data['filtered_counts'].items():
            n_filtered_out[type_key] += count

        n_out_samples = processed_data['n_out_samples']

    return n_filtered_out, n_out_samples

def calculate_timing_info(audio_data, facial_data, round_seconds_skeleton,
                          audio_fps, pose_fps, audio_sr, audio_rep):
    """Calculate timing information for the data."""
    if audio_data is not None:
        # The mfcc case must be checked before the generic non-wave16k branch,
        # otherwise it is unreachable.
        if audio_rep == "mfcc":
            round_seconds_audio = audio_data.shape[0] // audio_fps
        elif audio_rep != "wave16k":
            round_seconds_audio = len(audio_data) // audio_fps
        else:
            round_seconds_audio = audio_data.shape[0] // audio_sr

        if facial_data is not None:
            round_seconds_facial = facial_data.shape[0] // pose_fps
            logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s")
            final_seconds = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
            max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial)
            if final_seconds != max_round:
                logger.warning(f"reduce to {final_seconds}s, ignore {max_round-final_seconds}s")
        else:
            logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s")
            final_seconds = min(round_seconds_audio, round_seconds_skeleton)
            max_round = max(round_seconds_audio, round_seconds_skeleton)
            if final_seconds != max_round:
                logger.warning(f"reduce to {final_seconds}s, ignore {max_round-final_seconds}s")
    else:
        final_seconds = round_seconds_skeleton

    return {
        'final_seconds': final_seconds
    }

def calculate_clip_boundaries(round_seconds, clean_first_seconds, clean_final_seconds,
                              audio_fps, pose_fps):
    """Calculate the boundaries for clip sampling."""
    clip_s_t = clean_first_seconds
    clip_e_t = round_seconds - clean_final_seconds

    return {
        'clip_s_t': clip_s_t,
        'clip_e_t': clip_e_t,
        'clip_s_f_audio': audio_fps * clip_s_t,
        'clip_e_f_audio': clip_e_t * audio_fps,
        'clip_s_f_pose': clip_s_t * pose_fps,
        'clip_e_f_pose': clip_e_t * pose_fps
    }

def process_data_with_ratio(ori_stride, ori_length, ratio, clip_info, args, is_test,
                            audio_data, pose_data, trans_data, trans_v_data,
                            shape_data, facial_data, word_data, vid_data,
                            emo_data, sem_data, audio_file,
                            lmdb_manager, n_out_samples):
    """Process data with a specific training length ratio."""
    if is_test and not args.test_clip:
        cut_length = clip_info['clip_e_f_pose'] - clip_info['clip_s_f_pose']
        args.stride = cut_length
        max_length = cut_length
    else:
        args.stride = int(ratio * ori_stride)
        cut_length = int(ori_length * ratio)

    num_subdivision = math.floor(
        (clip_info['clip_e_f_pose'] - clip_info['clip_s_f_pose'] - cut_length) / args.stride
    ) + 1

    logger.info(f"pose from frame {clip_info['clip_s_f_pose']} to {clip_info['clip_e_f_pose']}, length {cut_length}")
    logger.info(f"{num_subdivision} clips is expected with stride {args.stride}")

    audio_short_length = 0  # only used when audio data is present
    if audio_data is not None:
        audio_short_length = math.floor(cut_length / args.pose_fps * args.audio_fps)
        logger.info(f"audio from frame {clip_info['clip_s_f_audio']} to {clip_info['clip_e_f_audio']}, length {audio_short_length}")

    # Process subdivisions
    filtered_counts = defaultdict(int)
    for i in range(num_subdivision):
        sample_data = extract_sample_data(
            i, clip_info, cut_length, args,
            audio_data, pose_data, trans_data, trans_v_data,
            shape_data, facial_data, word_data, vid_data,
            emo_data, sem_data, audio_file,
            audio_short_length
        )

        if sample_data['pose'].size > 0:  # skip windows that yield no pose frames
            lmdb_manager.add_sample([
                sample_data['pose'], sample_data['audio'], sample_data['facial'],
                sample_data['shape'], sample_data['word'], sample_data['emo'],
                sample_data['sem'], sample_data['vid'], sample_data['trans'],
                sample_data['trans_v'], sample_data['audio_name']
            ])
            n_out_samples += 1

    return {
        'filtered_counts': filtered_counts,
        'n_out_samples': n_out_samples
    }

def extract_sample_data(idx, clip_info, cut_length, args,
                        audio_data, pose_data, trans_data, trans_v_data,
                        shape_data, facial_data, word_data, vid_data,
                        emo_data, sem_data, audio_file,
                        audio_short_length):
    """Extract a single sample from the data."""
    start_idx = clip_info['clip_s_f_pose'] + idx * args.stride
    fin_idx = start_idx + cut_length

    sample_data = {
        'pose': pose_data[start_idx:fin_idx],
        'trans': trans_data[start_idx:fin_idx],
        'trans_v': trans_v_data[start_idx:fin_idx],
        'shape': shape_data[start_idx:fin_idx],
        'facial': facial_data[start_idx:fin_idx] if args.facial_rep is not None else np.array([-1]),
        'word': word_data[start_idx:fin_idx] if args.word_rep is not None else np.array([-1]),
        'emo': emo_data[start_idx:fin_idx] if args.emo_rep is not None else np.array([-1]),
        'sem': sem_data[start_idx:fin_idx] if args.sem_rep is not None else np.array([-1]),
        'vid': vid_data[start_idx:fin_idx] if args.id_rep is not None else np.array([-1]),
        'audio_name': audio_file
    }

    if audio_data is not None:
        audio_start = clip_info['clip_s_f_audio'] + math.floor(idx * args.stride * args.audio_fps / args.pose_fps)
        audio_end = audio_start + audio_short_length
        sample_data['audio'] = audio_data[audio_start:audio_end]
    else:
        sample_data['audio'] = np.array([-1])

    return sample_data
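For orientation, the subdivision count above is the usual sliding-window formula; a toy calculation with made-up numbers:

# Toy numbers for the sliding-window arithmetic in process_data_with_ratio:
# a 60 s clip at 30 fps, 64-frame windows, stride 20.
clip_s_f_pose, clip_e_f_pose = 0, 60 * 30
cut_length, stride = 64, 20
num_subdivision = (clip_e_f_pose - clip_s_f_pose - cut_length) // stride + 1
print(num_subdivision)  # 87 overlapping windows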
dataloaders/utils/mis_features.py
ADDED
@@ -0,0 +1,64 @@
# semantic_utils.py
import pandas as pd
import numpy as np
from loguru import logger
import os

def process_semantic_data(sem_file, args, data, f_name):
    """Process semantic representation data."""
    logger.info(f"# ---- Building cache for Semantic {f_name} ---- #")

    if not os.path.exists(sem_file):
        logger.warning(f"# ---- file not found for Semantic {f_name} ---- #")
        return None

    sem_all = pd.read_csv(sem_file,
                          sep='\t',
                          names=["name", "start_time", "end_time", "duration", "score", "keywords"])

    # Assign each pose frame the score of the semantic segment it falls into.
    sem_data = []
    for i in range(data['pose'].shape[0]):
        current_time = i / args.pose_fps
        found_score = False
        for _, row in sem_all.iterrows():
            if row['start_time'] <= current_time <= row['end_time']:
                sem_data.append(row['score'])
                found_score = True
                break
        if not found_score:
            sem_data.append(0.0)

    data['sem'] = np.array(sem_data)
    return data

def process_emotion_data(f_name, data, args):
    """Process emotion representation data."""
    logger.info(f"# ---- Building cache for Emotion {f_name} ---- #")

    # BEAT-style names look like "2_scott_0_65_65":
    # speaker id, speaker name, recording type, sequence start, sequence end.
    rtype, start = int(f_name.split('_')[2]), int(f_name.split('_')[3])
    if rtype in [0, 2, 4, 6]:
        if 1 <= start <= 64:
            score = 0
        elif 65 <= start <= 72:
            score = 1
        elif 73 <= start <= 80:
            score = 2
        elif 81 <= start <= 86:
            score = 3
        elif 87 <= start <= 94:
            score = 4
        elif 95 <= start <= 102:
            score = 5
        elif 103 <= start <= 110:
            score = 6
        elif 111 <= start <= 118:
            score = 7
        else:
            score = 0
    else:
        score = 0

    data['emo'] = np.repeat(np.array(score).reshape(1, 1), data['pose'].shape[0], axis=0)
    return data
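A small illustration of the mapping above; the file name is hypothetical and the labels follow the sequence-number bands in the code.

# Hypothetical check of the sequence-number -> emotion-score bands above.
f_name = "2_scott_0_65_65"
rtype, start = int(f_name.split('_')[2]), int(f_name.split('_')[3])
# rtype 0 is one of the emotion-annotated recording types; start = 65 lies
# in the 65-72 band, so the whole clip is tagged with emotion score 1.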
dataloaders/utils/motion_rep_transfer.py
ADDED
@@ -0,0 +1,236 @@
import smplx
import torch
import numpy as np
from . import rotation_conversions as rc
import os
import wget

download_path = "./datasets/hub"
smplx_model_dir = os.path.join(download_path, "smplx_models", "smplx")
if not os.path.exists(smplx_model_dir):
    smplx_model_file_path = os.path.join(smplx_model_dir, "SMPLX_NEUTRAL_2020.npz")
    os.makedirs(smplx_model_dir, exist_ok=True)
    if not os.path.exists(smplx_model_file_path):
        print(f"Downloading {smplx_model_file_path}")
        wget.download(
            "https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz",
            smplx_model_file_path,
        )

smplx_model = smplx.create(
    "./datasets/hub/smplx_models/",
    model_type='smplx',
    gender='NEUTRAL_2020',
    use_face_contour=False,
    num_betas=300,
    num_expression_coeffs=100,
    ext='npz',
    use_pca=False,
).eval()

def get_motion_rep_tensor(motion_tensor, pose_fps=30, device="cuda", betas=None):
    global smplx_model
    smplx_model = smplx_model.to(device)
    bs, n, _ = motion_tensor.shape
    motion_tensor = motion_tensor.float().to(device)
    motion_tensor_reshaped = motion_tensor.reshape(bs * n, 165)
    betas = torch.zeros(n, 300, device=device) if betas is None else betas.to(device).unsqueeze(0).repeat(n, 1)
    output = smplx_model(
        betas=torch.zeros(bs * n, 300, device=device),
        transl=torch.zeros(bs * n, 3, device=device),
        expression=torch.zeros(bs * n, 100, device=device),
        jaw_pose=torch.zeros(bs * n, 3, device=device),
        global_orient=torch.zeros(bs * n, 3, device=device),
        body_pose=motion_tensor_reshaped[:, 3:21 * 3 + 3],
        left_hand_pose=motion_tensor_reshaped[:, 25 * 3:40 * 3],
        right_hand_pose=motion_tensor_reshaped[:, 40 * 3:55 * 3],
        return_joints=True,
        leye_pose=torch.zeros(bs * n, 3, device=device),
        reye_pose=torch.zeros(bs * n, 3, device=device),
    )
    joints = output['joints'].reshape(bs, n, 127, 3)[:, :, :55, :]
    dt = 1 / pose_fps
    # Linear joint velocity: forward / central / backward differences
    init_vel = (joints[:, 1:2] - joints[:, 0:1]) / dt
    middle_vel = (joints[:, 2:] - joints[:, :-2]) / (2 * dt)
    final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
    vel = torch.cat([init_vel, middle_vel, final_vel], dim=1)
    position = joints
    rot_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(bs, n, 55, 3))
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(bs, n, 55, 6)
    init_vel_ang = (motion_tensor[:, 1:2] - motion_tensor[:, 0:1]) / dt
    middle_vel_ang = (motion_tensor[:, 2:] - motion_tensor[:, :-2]) / (2 * dt)
    final_vel_ang = (motion_tensor[:, -1:] - motion_tensor[:, -2:-1]) / dt
    angular_velocity = torch.cat([init_vel_ang, middle_vel_ang, final_vel_ang], dim=1).reshape(bs, n, 55, 3)
    rep15d = torch.cat([position, vel, rot6d, angular_velocity], dim=3).reshape(bs, n, 55 * 15)
    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": motion_tensor,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
    }

def get_motion_rep_numpy(poses_np, pose_fps=30, device="cuda", expressions=None, expression_only=False, betas=None):
    # poses_np is expected to be a numpy array of shape (n, 165):
    # axis-angle for 55 joints, i.e. 55 * 3.
    global smplx_model
    smplx_model = smplx_model.to(device)
    n = poses_np.shape[0]

    # Convert numpy to torch tensor for the SMPL-X forward pass
    poses_ts = torch.from_numpy(poses_np).float().to(device).unsqueeze(0)  # (1, n, 165)
    poses_ts_reshaped = poses_ts.reshape(-1, 165)  # (n, 165)
    betas = torch.zeros(n, 300, device=device) if betas is None else torch.from_numpy(betas).to(device).unsqueeze(0).repeat(n, 1)
    if expressions is not None and expression_only:
        expressions = torch.from_numpy(expressions).float().to(device)
        output = smplx_model(
            betas=betas,
            transl=torch.zeros(n, 3, device=device),
            expression=expressions,
            jaw_pose=poses_ts_reshaped[:, 22 * 3:23 * 3],
            global_orient=torch.zeros(n, 3, device=device),
            body_pose=torch.zeros(n, 21 * 3, device=device),
            left_hand_pose=torch.zeros(n, 15 * 3, device=device),
            right_hand_pose=torch.zeros(n, 15 * 3, device=device),
            return_joints=True,
            leye_pose=torch.zeros(n, 3, device=device),
            reye_pose=torch.zeros(n, 3, device=device),
        )
        vertices = output["vertices"].detach().cpu().numpy().reshape(n, -1)
        return {"vertices": vertices}

    # Run the smplx model to get joints
    output = smplx_model(
        betas=betas,
        transl=torch.zeros(n, 3, device=device),
        expression=torch.zeros(n, 100, device=device),
        jaw_pose=torch.zeros(n, 3, device=device),
        global_orient=torch.zeros(n, 3, device=device),
        body_pose=poses_ts_reshaped[:, 3:21 * 3 + 3],
        left_hand_pose=poses_ts_reshaped[:, 25 * 3:40 * 3],
        right_hand_pose=poses_ts_reshaped[:, 40 * 3:55 * 3],
        return_joints=True,
        leye_pose=torch.zeros(n, 3, device=device),
        reye_pose=torch.zeros(n, 3, device=device),
    )
    joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :]

    dt = 1 / pose_fps
    # Compute linear velocity
    init_vel = (joints[1:2] - joints[0:1]) / dt
    middle_vel = (joints[2:] - joints[:-2]) / (2 * dt)
    final_vel = (joints[-1:] - joints[-2:-1]) / dt
    vel = np.concatenate([init_vel, middle_vel, final_vel], axis=0)

    position = joints

    # Compute rotation 6D from axis-angle
    poses_ts_reshaped_aa = poses_ts.reshape(1, n, 55, 3)
    rot_matrices = rc.axis_angle_to_matrix(poses_ts_reshaped_aa)[0]  # (n, 55, 3, 3)
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy()

    # Compute angular velocity
    init_vel_ang = (poses_np[1:2] - poses_np[0:1]) / dt
    middle_vel_ang = (poses_np[2:] - poses_np[:-2]) / (2 * dt)
    final_vel_ang = (poses_np[-1:] - poses_np[-2:-1]) / dt
    angular_velocity = np.concatenate([init_vel_ang, middle_vel_ang, final_vel_ang], axis=0).reshape(n, 55, 3)

    # rep15d: position(55*3) + vel(55*3) + rot6d(55*6) + angular_velocity(55*3) => 55*15
    rep15d = np.concatenate([position, vel, rot6d, angular_velocity], axis=2).reshape(n, 55 * 15)

    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": poses_np,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
    }

def process_smplx_motion(pose_file, smplx_model, pose_fps, facial_rep=None):
    """Process SMPLX pose and facial data together."""
    pose_data = np.load(pose_file, allow_pickle=True)
    stride = int(30 / pose_fps)

    # Extract pose and facial data with the same stride
    pose_frames = pose_data["poses"][::stride]
    facial_frames = pose_data["expressions"][::stride] if facial_rep is not None else None

    # Process translations: zero the initial x/z offset
    trans = pose_data["trans"][::stride]
    trans[:, 0] = trans[:, 0] - trans[0, 0]
    trans[:, 2] = trans[:, 2] - trans[0, 2]

    # Calculate translation velocities (x/z as frame differences, y kept absolute)
    trans_v = np.zeros_like(trans)
    trans_v[1:, 0] = trans[1:, 0] - trans[:-1, 0]
    trans_v[0, 0] = trans_v[1, 0]
    trans_v[1:, 2] = trans[1:, 2] - trans[:-1, 2]
    trans_v[0, 2] = trans_v[1, 2]
    trans_v[:, 1] = trans[:, 1]

    # Process shape data
    shape = np.repeat(pose_data["betas"].reshape(1, 300), pose_frames.shape[0], axis=0)

    # # Calculate contacts
    # contacts = calculate_foot_contacts(pose_data, smplx_model)
    # if contacts is not None:
    #     pose_data = np.concatenate([pose_data, contacts], axis=1)

    return {
        'pose': pose_frames,
        'trans': trans,
        'trans_v': trans_v,
        'shape': shape,
        'facial': facial_frames if facial_frames is not None else np.array([-1])
    }

def calculate_foot_contacts(pose_data, smplx_model):
    """Calculate foot contacts from pose data."""
    max_length = 128
    all_tensor = []
    n = pose_data["poses"].shape[0]

    # Process in batches
    for i in range(n // max_length):
        joints = process_joints_batch(pose_data, i, max_length, smplx_model)
        all_tensor.append(joints)

    # Process remaining frames
    if n % max_length != 0:
        r = n % max_length
        joints = process_joints_batch(pose_data, n // max_length, r, smplx_model,
                                      remainder=True, full_batch_size=max_length)
        all_tensor.append(joints)

    # Calculate velocities and contacts
    joints = torch.cat(all_tensor, axis=0)
    feetv = torch.zeros(joints.shape[1], joints.shape[0])
    joints = joints.permute(1, 0, 2)
    feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1)
    contacts = (feetv < 0.01).numpy().astype(float)

    return contacts.transpose(1, 0)

def process_joints_batch(pose_data, batch_idx, batch_size, smplx_model, remainder=False, full_batch_size=128):
    """Process a batch of joints for contact calculation."""
    # For the remainder batch the offset is measured in full batches, not in
    # units of the (smaller) remainder size.
    start_idx = batch_idx * (full_batch_size if remainder else batch_size)
    end_idx = start_idx + batch_size

    with torch.no_grad():
        return smplx_model(
            betas=torch.from_numpy(pose_data["betas"]).cuda().float().repeat(batch_size, 1),
            transl=torch.from_numpy(pose_data["trans"][start_idx:end_idx]).cuda().float(),
            expression=torch.from_numpy(pose_data["expressions"][start_idx:end_idx]).cuda().float(),
            jaw_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 66:69]).cuda().float(),
            global_orient=torch.from_numpy(pose_data["poses"][start_idx:end_idx, :3]).cuda().float(),
            body_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 3:21*3+3]).cuda().float(),
            left_hand_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 25*3:40*3]).cuda().float(),
            right_hand_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 40*3:55*3]).cuda().float(),
            leye_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 69:72]).cuda().float(),
            reye_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 72:75]).cuda().float(),
            return_verts=True,
            return_joints=True
        )['joints'][:, (7, 8, 10, 11), :].reshape(batch_size, 4, 3).cpu()
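A minimal sketch of the per-frame representation produced above, assuming the SMPL-X model download at module import succeeded; the rest-pose input is purely illustrative.

# Minimal sketch: convert 3 s of SMPL-X axis-angle poses (rest pose here,
# just for illustration) into the 55 x 15 = 825-D per-frame representation.
poses = np.zeros((90, 165), dtype=np.float32)  # 90 frames at 30 fps
rep = get_motion_rep_numpy(poses, pose_fps=30, device="cpu")
print(rep["rep15d"].shape)  # (90, 825)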
dataloaders/utils/other_tools.py
ADDED
@@ -0,0 +1,748 @@
import os
import numpy as np
import random
import torch
import shutil
import csv
import pprint
import pandas as pd
from loguru import logger
from collections import OrderedDict
import matplotlib.pyplot as plt
import pickle
import time
import lmdb

def adjust_array(x, k):
    len_x = len(x)
    len_k = len(k)

    # If x is shorter than k, pad with zeros
    if len_x < len_k:
        return np.pad(x, (0, len_k - len_x), 'constant')
    # If x is longer than k, truncate x
    elif len_x > len_k:
        return x[:len_k]
    # If both are of same length
    else:
        return x

def onset_to_frame(onset_times, audio_length, fps):
    # Calculate total number of frames for the given audio length
    total_frames = int(audio_length * fps)

    # Create an array of zeros of shape (total_frames,)
    frame_array = np.zeros(total_frames, dtype=np.int32)

    # For each onset time, calculate the frame number and set it to 1
    for onset in onset_times:
        frame_num = int(onset * fps)
        # Check if the frame number is within the array bounds
        if 0 <= frame_num < total_frames:
            frame_array[frame_num] = 1

    return frame_array

def smooth_animations(animation1, animation2, blend_frames):
    """
    Smoothly transition between two animation clips using linear interpolation.

    Parameters:
    - animation1: The first animation clip, a numpy array of shape [n, k].
    - animation2: The second animation clip, a numpy array of shape [n, k].
    - blend_frames: Number of frames over which to blend the two animations.

    Returns:
    - A smoothly blended animation clip of shape [2n, k].
    """
    # Ensure blend_frames doesn't exceed the length of either animation
    blend_frames = min(blend_frames, len(animation1), len(animation2))

    # Extract overlapping sections
    overlap_a1 = animation1[-blend_frames:-blend_frames+1, :]
    overlap_a2 = animation2[blend_frames-1:blend_frames, :]

    # Create blend weights for linear interpolation
    alpha = np.linspace(0, 1, 2 * blend_frames).reshape(-1, 1)

    # Linearly interpolate between overlapping sections
    blended_overlap = overlap_a1 * (1 - alpha) + overlap_a2 * alpha

    # Extend the animations to form the result with 2n frames
    if blend_frames == len(animation1) and blend_frames == len(animation2):
        result = blended_overlap
    else:
        before_blend = animation1[:-blend_frames]
        after_blend = animation2[blend_frames:]
        result = np.vstack((before_blend, blended_overlap, after_blend))
    return result


def interpolate_sequence(quaternions):
    # Upsample a (bs, n, j, 4) quaternion sequence to 2n frames by slerping
    # a midpoint between every pair of consecutive frames.
    bs, n, j, _ = quaternions.shape
    new_n = 2 * n
    new_quaternions = torch.zeros((bs, new_n, j, 4), device=quaternions.device, dtype=quaternions.dtype)

    for i in range(n):
        q1 = quaternions[:, i, :, :]
        new_quaternions[:, 2*i, :, :] = q1

        if i < n - 1:
            q2 = quaternions[:, i + 1, :, :]
            new_quaternions[:, 2*i + 1, :, :] = slerp(q1, q2, 0.5)
        else:
            # For the last point, duplicate the value
            new_quaternions[:, 2*i + 1, :, :] = q1

    return new_quaternions

def quaternion_multiply(q1, q2):
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2
    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
    z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
    return w, x, y, z

def quaternion_conjugate(q):
    w, x, y, z = q
    return (w, -x, -y, -z)

def slerp(q1, q2, t):
    dot = torch.sum(q1 * q2, dim=-1, keepdim=True)

    # Flip q2 onto the same hemisphere as q1 so we interpolate the short way
    flip = (dot < 0).float()
    q2 = (1 - flip * 2) * q2
    dot = dot * (1 - flip * 2)

    # Fall back to lerp for nearly parallel quaternions
    DOT_THRESHOLD = 0.9995
    mask = (dot > DOT_THRESHOLD).float()

    # Clamp before acos so floating-point overshoot cannot produce NaNs
    theta_0 = torch.acos(dot.clamp(-1.0, 1.0))
    theta = theta_0 * t

    q3 = q2 - q1 * dot
    # Small epsilon guards against division by zero for identical quaternions
    q3 = q3 / (torch.norm(q3, dim=-1, keepdim=True) + 1e-8)

    interpolated = (torch.cos(theta) * q1 + torch.sin(theta) * q3)

    return mask * (q1 + t * (q2 - q1)) + (1 - mask) * interpolated
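A tiny sanity check for the slerp above, using the (w, x, y, z) convention the quaternion helpers in this file follow; the values are illustrative.

# Sanity check: the halfway point between a 0-degree and a 90-degree rotation
# about z should be a 45-degree rotation about z.
import math
q_a = torch.tensor([[1.0, 0.0, 0.0, 0.0]])                                   # identity
q_b = torch.tensor([[math.cos(math.pi/4), 0.0, 0.0, math.sin(math.pi/4)]])   # 90 deg about z
q_mid = slerp(q_a, q_b, 0.5)
print(q_mid)  # ~[[0.9239, 0, 0, 0.3827]], i.e. 45 deg about z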
def estimate_linear_velocity(data_seq, dt):
    '''
    Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates
    the velocity for the middle T-2 steps using a second order central difference scheme.
    The first and last frames are estimated with forward and backward first-order
    differences, respectively.
    - dt : time step size
    '''
    # first step is forward diff (t+1 - t) / dt
    init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt
    # middle steps are second order (t+1 - t-1) / 2dt
    middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt)
    # last step is backward diff (t - t-1) / dt
    final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt

    vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1)
    return vel_seq


def estimate_angular_velocity(rot_seq, dt):
    '''
    Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps.
    Input sequence should be of shape (B, T, ..., 3, 3)
    '''
    # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix
    dRdt = estimate_linear_velocity(rot_seq, dt)
    R = rot_seq
    RT = R.transpose(-1, -2)
    # compute skew-symmetric angular velocity tensor
    w_mat = torch.matmul(dRdt, RT)
    # pull out angular velocity vector by averaging symmetric entries
    w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0
    w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0
    w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0
    w = torch.stack([w_x, w_y, w_z], axis=-1)
    return w
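A quick numeric check of the angular-velocity extraction; the rotation sequence is constructed just for illustration.

# Quick check: a constant 1 rad/s rotation about the z-axis should give an
# angular-velocity vector close to (0, 0, 1) at the middle frames.
t = torch.arange(5, dtype=torch.float32) * 0.1
c, s = torch.cos(t), torch.sin(t)
R = torch.zeros(1, 5, 3, 3)
R[0, :, 0, 0], R[0, :, 0, 1] = c, -s
R[0, :, 1, 0], R[0, :, 1, 1] = s, c
R[0, :, 2, 2] = 1.0
w = estimate_angular_velocity(R, dt=0.1)
print(w[0, 2])  # ~tensor([0.0000, 0.0000, 0.9983])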
import matplotlib.image as mpimg
from io import BytesIO

def image_from_bytes(image_bytes):
    return mpimg.imread(BytesIO(image_bytes), format='PNG')


def process_frame(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1):
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import trimesh

    vertices = vertices_all[i]
    vertices1 = vertices1_all[i]
    filename = f"{output_dir}frame_{i}.png"
    filenames.append(filename)
    if i % 100 == 0:
        print('processed', i, 'frames')
    if use_matplotlib:
        fig = plt.figure(figsize=(20, 10))
        ax = fig.add_subplot(121, projection="3d")
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        #ax.view_init(elev=0, azim=90)
        x = vertices[:, 0]
        y = vertices[:, 1]
        z = vertices[:, 2]
        ax.scatter(x, y, z, s=0.5)
        ax.set_xlim([-1.0, 1.0])
        ax.set_ylim([-0.5, 1.5])  # height
        ax.set_zlim([-0, 2])  # depth
        ax.set_box_aspect((1, 1, 1))
    else:
        mesh = trimesh.Trimesh(vertices, faces)
        scene = mesh.scene()
        scene.camera.fov = camera_params['fov']
        scene.camera.resolution = camera_params['resolution']
        scene.camera.z_near = camera_params['z_near']
        scene.camera.z_far = camera_params['z_far']
        scene.graph[scene.camera.name] = camera_params['transform']
        fig, ax = plt.subplots(1, 2, figsize=(16, 6))
        image = scene.save_image(resolution=[640, 480], visible=False)
        im0 = ax[0].imshow(image_from_bytes(image))
        ax[0].axis('off')

    if use_matplotlib:
        ax2 = fig.add_subplot(122, projection="3d")
        ax2.set_box_aspect((1, 1, 1))
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        x1 = vertices1[:, 0]
        y1 = vertices1[:, 1]
        z1 = vertices1[:, 2]
        ax2.scatter(x1, y1, z1, s=0.5)
        ax2.set_xlim([-1.0, 1.0])
        ax2.set_ylim([-0.5, 1.5])  # height
        ax2.set_zlim([-0, 2])
        plt.savefig(filename, bbox_inches='tight')
        plt.close(fig)
    else:
        mesh1 = trimesh.Trimesh(vertices1, faces)
        scene1 = mesh1.scene()
        scene1.camera.fov = camera_params1['fov']
        scene1.camera.resolution = camera_params1['resolution']
        scene1.camera.z_near = camera_params1['z_near']
        scene1.camera.z_far = camera_params1['z_far']
        scene1.graph[scene1.camera.name] = camera_params1['transform']
        image1 = scene1.save_image(resolution=[640, 480], visible=False)
        im1 = ax[1].imshow(image_from_bytes(image1))
        ax[1].axis('off')
        plt.savefig(filename, bbox_inches='tight')
        plt.close(fig)
def generate_images(frames, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames):
    import multiprocessing
    import trimesh
    num_cores = multiprocessing.cpu_count()  # number of cores on this machine
    # Cache the camera parameters from the first frame of each sequence so
    # every worker renders with an identical view.
    mesh = trimesh.Trimesh(vertices_all[0], faces)
    scene = mesh.scene()
    camera_params = {
        'fov': scene.camera.fov,
        'resolution': scene.camera.resolution,
        'focal': scene.camera.focal,
        'z_near': scene.camera.z_near,
        "z_far": scene.camera.z_far,
        'transform': scene.graph[scene.camera.name][0]
    }
    mesh1 = trimesh.Trimesh(vertices1_all[0], faces)
    scene1 = mesh1.scene()
    camera_params1 = {
        'fov': scene1.camera.fov,
        'resolution': scene1.camera.resolution,
        'focal': scene1.camera.focal,
        'z_near': scene1.camera.z_near,
        "z_far": scene1.camera.z_far,
        'transform': scene1.graph[scene1.camera.name][0]
    }
    # Use a Pool to render all frames in parallel
    with multiprocessing.Pool(num_cores) as pool:
        pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)])
def render_one_sequence(
        res_npz_path,
        gt_npz_path,
        output_dir,
        audio_path,
        model_folder="/data/datasets/smplx_models/",
        model_type='smplx',
        gender='NEUTRAL_2020',
        ext='npz',
        num_betas=300,
        num_expression_coeffs=100,
        use_face_contour=False,
        use_matplotlib=False,
        args=None):
    import smplx
    import matplotlib.pyplot as plt
    import imageio
    from tqdm import tqdm
    import os
    import numpy as np
    import torch
    import moviepy.editor as mp
    import librosa

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = smplx.create(
        model_folder,
        model_type=model_type,
        gender=gender,
        use_face_contour=use_face_contour,
        num_betas=num_betas,
        num_expression_coeffs=num_expression_coeffs,
        ext=ext,
        use_pca=False,
    ).to(device)

    data_np_body = np.load(res_npz_path, allow_pickle=True)
    gt_np_body = np.load(gt_npz_path, allow_pickle=True)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    filenames = []
    if not use_matplotlib:
        import trimesh
        from pyvirtualdisplay import Display
        display = Display(visible=0, size=(640, 480))
        display.start()
    faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"]

    # Forward pass for the generated (result) motion
    n = data_np_body["poses"].shape[0]
    beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device)
    beta = beta.repeat(n, 1)
    expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).to(device)
    jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).to(device)
    pose = torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).to(device)
    transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).to(device)
    output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose,
                   global_orient=pose[:, :3], body_pose=pose[:, 3:21*3+3],
                   left_hand_pose=pose[:, 25*3:40*3], right_hand_pose=pose[:, 40*3:55*3],
                   leye_pose=pose[:, 69:72],
                   reye_pose=pose[:, 72:75],
                   return_verts=True)
    vertices_all = output["vertices"].cpu().detach().numpy()

    # Forward pass for the ground-truth motion
    beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device)
    expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).to(device)
    jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n, 66:69]).to(torch.float32).to(device)
    pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).to(device)
    transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).to(device)
    output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1,
                    global_orient=pose1[:, :3], body_pose=pose1[:, 3:21*3+3],
                    left_hand_pose=pose1[:, 25*3:40*3], right_hand_pose=pose1[:, 40*3:55*3],
                    leye_pose=pose1[:, 69:72],
                    reye_pose=pose1[:, 72:75], return_verts=True)
    vertices1_all = output1["vertices"].cpu().detach().numpy()

    if args.debug:
        seconds = 1
    else:
        seconds = vertices_all.shape[0] // 30

    # Render side-by-side frames in parallel, then collect them into a video
    generate_images(int(seconds*30), vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames)
    filenames = [f"{output_dir}frame_{i}.png" for i in range(int(seconds*30))]

    images = [imageio.imread(filename) for filename in filenames]
    imageio.mimsave(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4", images, fps=30)
    for filename in filenames:
        os.remove(filename)

    # Mux the (possibly trimmed) audio track onto the rendered video
    video = mp.VideoFileClip(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4")
    audio = mp.AudioFileClip(audio_path)
    if audio.duration > video.duration:
        audio = audio.subclip(0, video.duration)
    final_clip = video.set_audio(audio)
    final_clip.write_videofile(f"{output_dir}{res_npz_path.split('/')[-1][4:-4]}.mp4")
    os.remove(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4")
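A hedged invocation sketch for the renderer above; every path is a placeholder, and `args` only needs a `.debug` attribute for this function.

# Hypothetical invocation (all paths are placeholders).
from types import SimpleNamespace
render_one_sequence(
    "outputs/res_example.npz",       # generated motion
    "datasets/gt_example.npz",       # ground-truth motion
    "outputs/render/",
    "audio/example.wav",
    model_folder="./datasets/hub/smplx_models/",
    args=SimpleNamespace(debug=True),
)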
+
def print_exp_info(args):
|
| 479 |
+
logger.info(pprint.pformat(vars(args)))
|
| 480 |
+
logger.info(f"# ------------ {args.name} ----------- #")
|
| 481 |
+
logger.info("PyTorch version: {}".format(torch.__version__))
|
| 482 |
+
logger.info("CUDA version: {}".format(torch.version.cuda))
|
| 483 |
+
logger.info("{} GPUs".format(torch.cuda.device_count()))
|
| 484 |
+
logger.info(f"Random Seed: {args.random_seed}")
|
| 485 |
+
|
| 486 |
+
def args2csv(args, get_head=False, list4print=[]):
|
| 487 |
+
for k, v in args.items():
|
| 488 |
+
if isinstance(args[k], dict):
|
| 489 |
+
args2csv(args[k], get_head, list4print)
|
| 490 |
+
else: list4print.append(k) if get_head else list4print.append(v)
|
| 491 |
+
return list4print
|
| 492 |
+
|
| 493 |
+
class EpochTracker:
|
| 494 |
+
def __init__(self, metric_names, metric_directions):
|
| 495 |
+
assert len(metric_names) == len(metric_directions), "Metric names and directions should have the same length"
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
self.metric_names = metric_names
|
| 499 |
+
self.states = ['train', 'val', 'test']
|
| 500 |
+
self.types = ['last', 'best']
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
self.values = {name: {state: {type_: {'value': np.inf if not is_higher_better else -np.inf, 'epoch': 0}
|
| 504 |
+
for type_ in self.types}
|
| 505 |
+
for state in self.states}
|
| 506 |
+
for name, is_higher_better in zip(metric_names, metric_directions)}
|
| 507 |
+
|
| 508 |
+
self.loss_meters = {name: {state: AverageMeter(f"{name}_{state}")
|
| 509 |
+
for state in self.states}
|
| 510 |
+
for name in metric_names}
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
self.is_higher_better = {name: direction for name, direction in zip(metric_names, metric_directions)}
|
| 514 |
+
self.train_history = {name: [] for name in metric_names}
|
| 515 |
+
self.val_history = {name: [] for name in metric_names}
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def update_meter(self, name, state, value):
|
| 519 |
+
self.loss_meters[name][state].update(value)
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def update_values(self, name, state, epoch):
|
| 523 |
+
value_avg = self.loss_meters[name][state].avg
|
| 524 |
+
new_best = False
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
if ((value_avg < self.values[name][state]['best']['value'] and not self.is_higher_better[name]) or
|
| 528 |
+
(value_avg > self.values[name][state]['best']['value'] and self.is_higher_better[name])):
|
| 529 |
+
self.values[name][state]['best']['value'] = value_avg
|
| 530 |
+
self.values[name][state]['best']['epoch'] = epoch
|
| 531 |
+
new_best = True
|
| 532 |
+
self.values[name][state]['last']['value'] = value_avg
|
| 533 |
+
self.values[name][state]['last']['epoch'] = epoch
|
| 534 |
+
return new_best
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def get(self, name, state, type_):
|
| 538 |
+
return self.values[name][state][type_]
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def reset(self):
|
| 542 |
+
for name in self.metric_names:
|
| 543 |
+
for state in self.states:
|
| 544 |
+
self.loss_meters[name][state].reset()
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def flatten_values(self):
|
| 548 |
+
flat_dict = {}
|
| 549 |
+
for name in self.metric_names:
|
| 550 |
+
for state in self.states:
|
| 551 |
+
for type_ in self.types:
|
| 552 |
+
value_key = f"{name}_{state}_{type_}"
|
| 553 |
+
epoch_key = f"{name}_{state}_{type_}_epoch"
|
| 554 |
+
flat_dict[value_key] = self.values[name][state][type_]['value']
|
| 555 |
+
flat_dict[epoch_key] = self.values[name][state][type_]['epoch']
|
| 556 |
+
return flat_dict
|
| 557 |
+
|
| 558 |
+
def update_and_plot(self, name, epoch, save_path):
|
| 559 |
+
new_best_train = self.update_values(name, 'train', epoch)
|
| 560 |
+
new_best_val = self.update_values(name, 'val', epoch)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
self.train_history[name].append(self.loss_meters[name]['train'].avg)
|
| 564 |
+
self.val_history[name].append(self.loss_meters[name]['val'].avg)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
train_values = self.train_history[name]
|
| 568 |
+
val_values = self.val_history[name]
|
| 569 |
+
epochs = list(range(1, len(train_values) + 1))
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
plt.figure(figsize=(10, 6))
|
| 573 |
+
plt.plot(epochs, train_values, label='Train')
|
| 574 |
+
plt.plot(epochs, val_values, label='Val')
|
| 575 |
+
plt.title(f'Train vs Val {name} over epochs')
|
| 576 |
+
plt.xlabel('Epochs')
|
| 577 |
+
plt.ylabel(name)
|
| 578 |
+
plt.legend()
|
| 579 |
+
plt.savefig(save_path)
|
| 580 |
+
plt.close()
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
return new_best_train, new_best_val
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
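+# Typical per-epoch flow (metric names and values illustrative; direction
+# False means lower is better, True means higher is better):
+#   tracker = EpochTracker(["fid", "acc"], [False, True])
+#   tracker.update_meter("fid", "train", 0.42)       # accumulate batch values
+#   new_best = tracker.update_values("fid", "train", epoch)
+#   tracker.reset()                                  # clear meters for the next epoch
+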
+def record_trial(args, tracker):
+    """
+    Record notes, scores, env_name, and experiment paths for one trial.
+    """
+    csv_path = args.out_path + "custom/" + args.csv_name + ".csv"
+    all_print_dict = vars(args)
+    all_print_dict.update(tracker.flatten_values())
+    if not os.path.exists(csv_path):
+        pd.DataFrame([all_print_dict]).to_csv(csv_path, index=False)
+    else:
+        df_existing = pd.read_csv(csv_path)
+        df_new = pd.DataFrame([all_print_dict])
+        # DataFrame.append was removed in pandas 2.0; pd.concat is the
+        # supported replacement and aligns mismatched columns the same way.
+        df_aligned = pd.concat([df_existing, df_new]).fillna("")
+        df_aligned.to_csv(csv_path, index=False)
+
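+# Each call appends one row per trial; columns present in only some runs are
+# back-filled with "" so the CSV stays rectangular across experiments.
+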
+def set_random_seed(args):
+    os.environ['PYTHONHASHSEED'] = str(args.random_seed)
+    random.seed(args.random_seed)
+    np.random.seed(args.random_seed)
+    torch.manual_seed(args.random_seed)
+    torch.cuda.manual_seed_all(args.random_seed)
+    torch.cuda.manual_seed(args.random_seed)
+    torch.backends.cudnn.deterministic = args.deterministic  # args.CUDNN_DETERMINISTIC
+    torch.backends.cudnn.benchmark = args.benchmark
+    torch.backends.cudnn.enabled = args.cudnn_enabled
+
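+# Note: cudnn.deterministic=True trades speed for reproducible kernels, while
+# cudnn.benchmark=True auto-tunes kernels per input shape and sacrifices
+# determinism, so the two flags are normally not enabled together.
+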
+def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None):
+    if lrs is not None:
+        states = {'model_state': model.state_dict(),
+                  'epoch': epoch + 1,
+                  'opt_state': opt.state_dict(),
+                  'lrs': lrs.state_dict()}
+    elif opt is not None:
+        states = {'model_state': model.state_dict(),
+                  'epoch': epoch + 1,
+                  'opt_state': opt.state_dict()}
+    else:
+        states = {'model_state': model.state_dict()}
+    torch.save(states, save_path)
+
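+# Three levels of checkpoint detail, from weights-only to full training state
+# (path and variables illustrative):
+#   save_checkpoints("last.bin", model)                    # weights only
+#   save_checkpoints("last.bin", model, opt, epoch)        # + optimizer
+#   save_checkpoints("last.bin", model, opt, epoch, lrs)   # + LR scheduler
+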
+def load_checkpoints(model, save_path, load_name='model'):
+    states = torch.load(save_path)
+    new_weights = OrderedDict()
+    flag = False
+    for k, v in states['model_state'].items():
+        # print(k)
+        if "module" not in k:
+            break
+        else:
+            # Strip the "module." prefix that nn.DataParallel adds to keys.
+            new_weights[k[7:]] = v
+            flag = True
+    if flag:
+        try:
+            model.load_state_dict(new_weights)
+        except Exception:
+            # print(states['model_state'])
+            model.load_state_dict(states['model_state'])
+    else:
+        model.load_state_dict(states['model_state'])
+    logger.info(f"load self-pretrained checkpoints for {load_name}")
+
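+# Handles checkpoints saved either from a bare model or from one wrapped in
+# nn.DataParallel, e.g. (path illustrative):
+#   load_checkpoints(vq_model, "ckpt/net_300000.bin", load_name="vq_model")
+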
+def model_complexity(model, args):
+    from ptflops import get_model_complexity_info
+    flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN),
+                                              as_strings=False, print_per_layer_stat=False)
+    logging.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9))
+    logging.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6))
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
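+# Minimal usage (values illustrative):
+#   meter = AverageMeter("loss", fmt=":.4f")
+#   meter.update(0.73, n=32)   # weight the running mean by batch size
+#   str(meter)                 # -> "loss 0.7300 (0.7300)"
+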
+class MultiLMDBManager:
+    def __init__(self, base_dir, max_db_size=10*1024*1024*1024):  # 10GB default size
+        self.base_dir = base_dir
+        self.max_db_size = max_db_size
+        self.current_db_size = 0
+        self.current_db_idx = 0
+        self.current_lmdb_env = None
+        self.sample_to_db_mapping = {}
+        self.sample_counter = 0
+        self.db_paths = []
+
+    def get_new_lmdb_path(self):
+        db_path = os.path.join(self.base_dir, f"db_{self.current_db_idx:03d}")
+        self.db_paths.append(db_path)
+        return db_path
+
+    def init_new_db(self):
+        if self.current_lmdb_env is not None:
+            self.current_lmdb_env.sync()
+            self.current_lmdb_env.close()
+
+        new_db_path = self.get_new_lmdb_path()
+        self.current_lmdb_env = lmdb.open(new_db_path, map_size=self.max_db_size)
+        self.current_db_size = 0
+        self.current_db_idx += 1
+        return self.current_lmdb_env
+
+    def add_sample(self, sample_data):
+        if self.current_lmdb_env is None:
+            self.init_new_db()
+
+        v = pickle.dumps(sample_data)
+        sample_size = len(v)
+
+        try:
+            sample_key = "{:08d}".format(self.sample_counter).encode("ascii")
+            with self.current_lmdb_env.begin(write=True) as txn:
+                txn.put(sample_key, v)
+            self.sample_to_db_mapping[self.sample_counter] = self.current_db_idx - 1
+        except lmdb.MapFullError:
+            # The current shard is full: roll over to a fresh database and retry.
+            self.init_new_db()
+            sample_key = "{:08d}".format(self.sample_counter).encode("ascii")
+            with self.current_lmdb_env.begin(write=True) as txn:
+                txn.put(sample_key, v)
+            self.sample_to_db_mapping[self.sample_counter] = self.current_db_idx - 1
+
+        self.current_db_size += sample_size
+        self.sample_counter += 1
+
+    def save_mapping(self):
+        mapping_path = os.path.join(self.base_dir, "sample_db_mapping.pkl")
+        with open(mapping_path, 'wb') as f:
+            pickle.dump({
+                'mapping': self.sample_to_db_mapping,
+                'db_paths': self.db_paths
+            }, f)
+
+    def close(self):
+        if self.current_lmdb_env is not None:
+            self.current_lmdb_env.sync()
+            self.current_lmdb_env.close()
+        self.save_mapping()