-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvln_ce.py
More file actions
113 lines (86 loc) · 3.71 KB
/
vln_ce.py
File metadata and controls
113 lines (86 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import logging
from argparse import ArgumentParser
from pathlib import Path
logging.basicConfig(level=logging.INFO)
import time
import json
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from utils.lerobot.lerobot_creater import LeRobotCreator
from utils import Trajectories
from utils.vln_ce.trajectory import VLN_CE_Trajectories
parser = ArgumentParser(description="Port VLN-CE dataset to LeRobotDataset format")
parser.add_argument("--raw_dir", type=str, default="InternData-n1-demo", help="Path to the raw VLN-CE dataset directory")
parser.add_argument("--output_dir", type=str, default=".", help="Path to the output LeRobotDataset directory")
parser.add_argument("--codec", type=str, default="h264", choices=["h264", "hevc", "libsvtav1"], help="Video codec to use for encoding")
parser.add_argument("--num_processes", type=int, default=8, help="Number of processes for image writing")
args = parser.parse_args()
def port(
raw_dir: str,
repo_id: str,
root: str,
traj_cls: type[Trajectories],
num_processes: int,
codec: str = "h264",
*args, **kwargs
):
"""Port raw dataset to LeRobotDataset format."""
logging.info(f"Porting raw dataset from {raw_dir} to LeRobotDataset repo {repo_id}")
# VLN_CE_Trajectories.FEATURES defines the schema
features = traj_cls.FEATURES
logging.info(f"Creating LeRobotCreator at {root}")
creator = LeRobotCreator(
root=str(root),
robot_type=traj_cls.ROBOT_TYPE,
fps=traj_cls.FPS,
features=features,
num_workers=max(1, num_processes),
num_video_encoders=int(max(1, num_processes) * 1.75), # more encoders than workers
codec=codec,
has_extras=True,
)
def get_task_idx(task: str) -> int:
return creator.add_task(task)
trajectories = traj_cls(raw_dir, get_task_idx=get_task_idx)
start_time = time.time()
num_episodes = len(trajectories)
logging.info(f"Number of episodes {num_episodes}")
for episode_index, episode in enumerate(trajectories):
creator.submit_episode(episode)
elapsed_time = time.time() - start_time
if (episode_index + 1) % 10 == 0:
logging.info('\033[92m' + f"Submitted {episode_index + 1} / {num_episodes} episodes (elapsed {elapsed_time:.3f} s)" + '\033[0m')
logging.info("All episodes submitted. Waiting for completion...")
creator.wait()
logging.info("LeRobotCreator finished.")
def validate_dataset(repo_id, root: str):
"""Sanity check that ensure meta data can be loaded and all files are present."""
meta = LeRobotDatasetMetadata(repo_id, root=root)
if meta.total_episodes == 0:
raise ValueError("Number of episodes is 0.")
for ep_idx in range(meta.total_episodes):
data_path = meta.root / meta.get_data_file_path(ep_idx)
if not data_path.exists():
raise ValueError(f"Parquet file is missing in: {data_path}")
for vid_key in meta.video_keys:
vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
if not vid_path.exists():
raise ValueError(f"Video file is missing in: {vid_path}")
def main():
raw_dir = Path(args.raw_dir)
folder_name = raw_dir.name
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Use a distinct folder name for CE dataset
dataset_name = f"VLN-CE-{folder_name}"
root = output_dir / dataset_name
port(
raw_dir=raw_dir,
repo_id=dataset_name,
root=root,
traj_cls=VLN_CE_Trajectories,
num_processes=args.num_processes,
codec=args.codec,
)
validate_dataset(dataset_name, root=root)
if __name__ == "__main__":
main()