|
34 | 34 | "import os\n", |
35 | 35 | "import shutil\n", |
36 | 36 | "import random\n", |
37 | | - "import csv\n" |
| 37 | + "import csv" |
38 | 38 | ] |
39 | 39 | }, |
40 | 40 | { |
|
52 | 52 | } |
53 | 53 | ], |
54 | 54 | "source": [ |
55 | | - "\n", |
56 | 55 | "os.chdir(\"..\")\n", |
57 | | - "print(os.getcwd())\n" |
| 56 | + "print(os.getcwd())" |
58 | 57 | ] |
59 | 58 | }, |
60 | 59 | { |
|
98 | 97 | "frame_files = sorted(os.listdir(frames_dir))\n", |
99 | 98 | "\n", |
100 | 99 | "print(f\"Number of face files: {len(face_files)}\")\n", |
101 | | - "print(f\"Number of frame files: {len(frame_files)}\") " |
| 100 | + "print(f\"Number of frame files: {len(frame_files)}\")" |
102 | 101 | ] |
103 | 102 | }, |
104 | 103 | { |
|
147 | 146 | " train_frames = []\n", |
148 | 147 | " test_frames = []\n", |
149 | 148 | " train_face_set = set(face_id_from_filename(f) for f in train_faces)\n", |
150 | | - " \n", |
| 149 | + "\n", |
151 | 150 | " for frame in all_frames:\n", |
152 | 151 | "\n", |
153 | 152 | " face_id = face_id_from_filename(frame)\n", |
154 | 153 | " if face_id in train_face_set:\n", |
155 | 154 | " train_frames.append(frame)\n", |
156 | 155 | " else:\n", |
157 | 156 | " test_frames.append(frame)\n", |
158 | | - " \n", |
159 | | - " return train_frames, test_frames\n" |
| 157 | + "\n", |
| 158 | + " return train_frames, test_frames" |
160 | 159 | ] |
161 | 160 | }, |
162 | 161 | { |
|
290 | 289 | " src = os.path.join(frames_dir, f)\n", |
291 | 290 | " dst = os.path.join(train_frames_dir, f)\n", |
292 | 291 | " shutil.copyfile(src, dst)\n", |
293 | | - " \n", |
| 292 | + "\n", |
294 | 293 | " for f in te_frames:\n", |
295 | 294 | " src = os.path.join(frames_dir, f)\n", |
296 | 295 | " dst = os.path.join(test_frames_dir, f)\n", |
|
328 | 327 | "for ratio in split_ratios:\n", |
329 | 328 | " split_dir = os.path.join(splits_root, f\"split_{int(ratio*100)}\")\n", |
330 | 329 | " os.makedirs(split_dir, exist_ok=True)\n", |
331 | | - " create_partitioned_set(ratio, dataset_root, split_dir)\n" |
| 330 | + " create_partitioned_set(ratio, dataset_root, split_dir)" |
332 | 331 | ] |
333 | 332 | }, |
334 | 333 | { |
|
353 | 352 | " frame_files = sorted(os.listdir(frames_dir))\n", |
354 | 353 | " return face_files, frame_files\n", |
355 | 354 | "\n", |
| 355 | + "\n", |
356 | 356 | "def get_matches(face_id, frame_files, n_pairs):\n", |
357 | 357 | " matches = [f for f in frame_files if face_id_from_filename(f) == face_id]\n", |
358 | 358 | " return random.sample(matches, n_pairs)\n", |
359 | 359 | "\n", |
| 360 | + "\n", |
360 | 361 | "def get_mismatches(face_id, frame_files, n_pairs):\n", |
361 | 362 | " mismatches = [f for f in frame_files if face_id_from_filename(f) != face_id]\n", |
362 | 363 | " return random.sample(mismatches, n_pairs)\n", |
363 | 364 | "\n", |
| 365 | + "\n", |
364 | 366 | "def create_pairs(face_files, frame_files, n_pairs_per_face):\n", |
365 | 367 | " match_pairs = []\n", |
366 | 368 | " mismatch_pairs = []\n", |
|
403 | 405 | "print(f\"Number of faces: {len(faces)}\")\n", |
404 | 406 | "print(f\"Number of frames: {len(frames)}\")\n", |
405 | 407 | "\n", |
406 | | - "n_train = int(len(faces) * 2/3)\n", |
| 408 | + "n_train = int(len(faces) * 2 / 3)\n", |
407 | 409 | "n_val = len(faces) - n_train\n", |
408 | 410 | "print(f\"Number of training faces: {n_train}\")\n", |
409 | 411 | "print(f\"Number of validation faces: {n_val}\")\n", |
410 | 412 | "\n", |
411 | 413 | "train_faces = faces[:n_train]\n", |
412 | 414 | "val_faces = faces[n_train:]\n", |
413 | 415 | "\n", |
414 | | - "train_frames = frames[:n_train*31]\n", |
415 | | - "val_frames = frames[n_train*31:]\n", |
| 416 | + "train_frames = frames[: n_train * 31]\n", |
| 417 | + "val_frames = frames[n_train * 31 :]\n", |
416 | 418 | "\n", |
417 | 419 | "print(train_faces[-4:])\n", |
418 | 420 | "print(val_faces[:4])\n", |
|
1488 | 1490 | ], |
1489 | 1491 | "source": [ |
1490 | 1492 | "def save_csv(pairs, filepath):\n", |
1491 | | - " with open(filepath, mode='w', newline='') as file:\n", |
| 1493 | + " with open(filepath, mode=\"w\", newline=\"\") as file:\n", |
1492 | 1494 | " writer = csv.writer(file)\n", |
1493 | 1495 | " writer.writerow([\"face\", \"frame\"])\n", |
1494 | 1496 | " for face, frame in pairs:\n", |
1495 | 1497 | " writer.writerow([face, frame])\n", |
1496 | 1498 | " print(f\"Saved {len(pairs)} pairs to {filepath}\")\n", |
1497 | 1499 | "\n", |
| 1500 | + "\n", |
1498 | 1501 | "save_csv(match_train_pairs, \"data/rococo2v3-dev/train_match_pairs.csv\")\n", |
1499 | 1502 | "save_csv(mismatch_train_pairs, \"data/rococo2v3-dev/train_mismatch_pairs.csv\")\n", |
1500 | 1503 | "save_csv(match_val_pairs, \"data/rococo2v3-dev/val_match_pairs.csv\")\n", |
|
1529 | 1532 | } |
1530 | 1533 | ], |
1531 | 1534 | "source": [ |
1532 | | - "used_frames = set([\n", |
1533 | | - " *(frame for _, frame in match_train_pairs),\n", |
1534 | | - " *(frame for _, frame in mismatch_train_pairs),\n", |
1535 | | - " *(frame for _, frame in match_val_pairs),\n", |
1536 | | - " *(frame for _, frame in mismatch_val_pairs),\n", |
1537 | | - "])\n", |
| 1535 | + "used_frames = set(\n", |
| 1536 | + " [\n", |
| 1537 | + " *(frame for _, frame in match_train_pairs),\n", |
| 1538 | + " *(frame for _, frame in mismatch_train_pairs),\n", |
| 1539 | + " *(frame for _, frame in match_val_pairs),\n", |
| 1540 | + " *(frame for _, frame in mismatch_val_pairs),\n", |
| 1541 | + " ]\n", |
| 1542 | + ")\n", |
1538 | 1543 | "\n", |
1539 | | - "len(used_frames), sum((len(match_train_pairs), len(mismatch_train_pairs), len(match_val_pairs), len(mismatch_val_pairs)))" |
| 1544 | + "len(used_frames), sum(\n", |
| 1545 | + " (\n", |
| 1546 | + " len(match_train_pairs),\n", |
| 1547 | + " len(mismatch_train_pairs),\n", |
| 1548 | + " len(match_val_pairs),\n", |
| 1549 | + " len(mismatch_val_pairs),\n", |
| 1550 | + " )\n", |
| 1551 | + ")" |
1540 | 1552 | ] |
1541 | 1553 | }, |
1542 | 1554 | { |
|
0 commit comments