-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcompute-statistics.py
More file actions
51 lines (41 loc) · 1.53 KB
/
compute-statistics.py
File metadata and controls
51 lines (41 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import tqdm
import hydra
import torch
import importlib
import omegaconf
import numpy as np
from src.constants import (
DEFAULT_HYDRA_CONFIG_PATH,
DEFAULT_HYDRA_VERSION_BASE,
)
class _RunningMoments:
    """Accumulate per-feature mean and population variance over rows of arrays.

    Each ``update`` call folds one array into running totals using the
    pairwise (Chan et al.) merge of count / mean / sum-of-squared-deviations,
    computed in float64. The result matches concatenating every array and
    calling ``mean(axis=0)`` / ``std(axis=0)`` (up to rounding), without
    holding the whole dataset in memory.
    """

    def __init__(self, name):
        self.name = name  # stats key; used in error messages
        self.count = 0  # number of rows folded in so far
        self.mean = None  # running per-feature mean (float64)
        self.m2 = None  # running per-feature sum of squared deviations (float64)
        self.feature_shape = None  # shape of one row, fixed by the first array seen
        self.dtype = None  # dtype np.concatenate would have produced for all inputs

    def update(self, values):
        """Fold the rows of ``values`` (any array-like) into the running moments.

        Raises:
            ValueError: if ``values`` is a scalar or its per-row shape differs
                from that of earlier arrays for this key.
        """
        arr = np.asarray(values)
        if arr.ndim == 0:
            raise ValueError(f"{self.name}: expected an array with at least one axis, got a scalar")
        # Flattening rule kept from the concatenation-based version: arrays with
        # more than two axes collapse to (-1, last_dim); 1-D/2-D arrays are used as-is.
        rows = arr.reshape(-1, arr.shape[-1]) if arr.ndim > 2 else arr
        if self.feature_shape is None:
            self.feature_shape = rows.shape[1:]
            self.dtype = rows.dtype
        elif rows.shape[1:] != self.feature_shape:
            raise ValueError(
                f"{self.name}: inconsistent feature shape {rows.shape[1:]}, "
                f"expected {self.feature_shape}"
            )
        else:
            self.dtype = np.result_type(self.dtype, rows.dtype)

        n_new = rows.shape[0]
        if n_new == 0:
            return
        batch = rows.astype(np.float64, copy=False)
        mean_new = batch.mean(axis=0)
        m2_new = np.square(batch - mean_new).sum(axis=0)
        if self.count == 0:
            self.count, self.mean, self.m2 = n_new, mean_new, m2_new
            return
        # Chan et al. merge of two (count, mean, M2) summaries.
        total = self.count + n_new
        delta = mean_new - self.mean
        self.mean = self.mean + delta * (n_new / total)
        self.m2 = self.m2 + m2_new + np.square(delta) * (self.count * n_new / total)
        self.count = total

    def finalize(self):
        """Return ``(mean, std + 1e-8)``, or ``None`` if no rows were seen."""
        if self.count == 0:
            return None
        std = np.sqrt(self.m2 / self.count) + 1e-8  # population std (ddof=0), like ndarray.std
        # Mirror ndarray.mean/std: floating inputs keep their dtype, others become float64.
        out_dtype = self.dtype if np.issubdtype(self.dtype, np.floating) else np.float64
        return self.mean.astype(out_dtype), std.astype(out_dtype)


@hydra.main(config_path=DEFAULT_HYDRA_CONFIG_PATH, config_name="compute-statistics", version_base=DEFAULT_HYDRA_VERSION_BASE)  # type: ignore
def main(cfg: omegaconf.DictConfig) -> None:
    """Compute per-feature mean/std of training-split motion and save them.

    ``cfg.data`` is instantiated on the ``"train"`` split with
    ``motion_normalizer=None`` so the statistics are taken over unnormalized
    features. Each item's ``"motion"`` entry is either a dict that may hold
    ``"new_joint_vecs"`` and/or ``"new_joints"``, or a bare array-like that is
    treated as ``"new_joint_vecs"``. Items without a motion are skipped.

    Moments are accumulated incrementally, so memory does not grow with the
    dataset size. The result ``{"mean": {key: array}, "std": {key: array}}``
    is written with ``torch.save`` to ``cfg.output_path``. The parent
    directory is created if needed. Keys for which no data was seen are
    omitted, and ``std`` includes a ``1e-8`` offset so it is safe to divide by.

    Raises:
        ValueError: if arrays for one key have inconsistent feature shapes.
        RuntimeError: if no motion data was found (no file is written).
    """
    dataset = hydra.utils.instantiate(
        cfg.data,
        split="train",
        motion_normalizer=None,
    )

    accumulators = {key: _RunningMoments(key) for key in ("new_joint_vecs", "new_joints")}
    for item in tqdm.tqdm(dataset, desc="[collecting-motion-stats]"):
        motion = item.get("motion", None)
        if motion is None:
            continue
        if isinstance(motion, dict):
            for key, acc in accumulators.items():
                if key in motion:
                    acc.update(motion[key])
        else:
            # A bare (non-dict) motion is treated as the joint-feature vectors.
            accumulators["new_joint_vecs"].update(motion)

    stats = {"mean": {}, "std": {}}
    for key, acc in accumulators.items():
        result = acc.finalize()
        if result is not None:
            stats["mean"][key], stats["std"][key] = result

    if not stats["mean"]:
        # Saving an empty dict would only surface later as a KeyError in consumers.
        raise RuntimeError(
            "No motion data found in the training split; refusing to save empty statistics."
        )

    out_dir = os.path.dirname(cfg.output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): the saved values are NumPy arrays. torch>=2.6 defaults
    # torch.load(weights_only=True), which rejects them unless the loader
    # opts out or allowlists NumPy types. Confirm against the loading code.
    torch.save(stats, cfg.output_path)
    print(f"Saved stats to {os.path.abspath(cfg.output_path)}")
if __name__ == "__main__":
    # Script entry point: ``main`` is called with no arguments because the
    # ``@hydra.main`` decorator supplies ``cfg`` itself.
    main()