Skip to content

Commit 49697d7

Browse files
authored
Merge pull request #62 from mGarbowski/plots
Plots
2 parents f714b38 + 6fd1c80 commit 49697d7

13 files changed

Lines changed: 2057 additions & 37 deletions

notebooks/11-rococo-train-test-split.ipynb

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
"import os\n",
3535
"import shutil\n",
3636
"import random\n",
37-
"import csv\n"
37+
"import csv"
3838
]
3939
},
4040
{
@@ -52,9 +52,8 @@
5252
}
5353
],
5454
"source": [
55-
"\n",
5655
"os.chdir(\"..\")\n",
57-
"print(os.getcwd())\n"
56+
"print(os.getcwd())"
5857
]
5958
},
6059
{
@@ -98,7 +97,7 @@
9897
"frame_files = sorted(os.listdir(frames_dir))\n",
9998
"\n",
10099
"print(f\"Number of face files: {len(face_files)}\")\n",
101-
"print(f\"Number of frame files: {len(frame_files)}\") "
100+
"print(f\"Number of frame files: {len(frame_files)}\")"
102101
]
103102
},
104103
{
@@ -147,16 +146,16 @@
147146
" train_frames = []\n",
148147
" test_frames = []\n",
149148
" train_face_set = set(face_id_from_filename(f) for f in train_faces)\n",
150-
" \n",
149+
"\n",
151150
" for frame in all_frames:\n",
152151
"\n",
153152
" face_id = face_id_from_filename(frame)\n",
154153
" if face_id in train_face_set:\n",
155154
" train_frames.append(frame)\n",
156155
" else:\n",
157156
" test_frames.append(frame)\n",
158-
" \n",
159-
" return train_frames, test_frames\n"
157+
"\n",
158+
" return train_frames, test_frames"
160159
]
161160
},
162161
{
@@ -290,7 +289,7 @@
290289
" src = os.path.join(frames_dir, f)\n",
291290
" dst = os.path.join(train_frames_dir, f)\n",
292291
" shutil.copyfile(src, dst)\n",
293-
" \n",
292+
"\n",
294293
" for f in te_frames:\n",
295294
" src = os.path.join(frames_dir, f)\n",
296295
" dst = os.path.join(test_frames_dir, f)\n",
@@ -328,7 +327,7 @@
328327
"for ratio in split_ratios:\n",
329328
" split_dir = os.path.join(splits_root, f\"split_{int(ratio*100)}\")\n",
330329
" os.makedirs(split_dir, exist_ok=True)\n",
331-
" create_partitioned_set(ratio, dataset_root, split_dir)\n"
330+
" create_partitioned_set(ratio, dataset_root, split_dir)"
332331
]
333332
},
334333
{
@@ -353,14 +352,17 @@
353352
" frame_files = sorted(os.listdir(frames_dir))\n",
354353
" return face_files, frame_files\n",
355354
"\n",
355+
"\n",
356356
"def get_matches(face_id, frame_files, n_pairs):\n",
357357
" matches = [f for f in frame_files if face_id_from_filename(f) == face_id]\n",
358358
" return random.sample(matches, n_pairs)\n",
359359
"\n",
360+
"\n",
360361
"def get_mismatches(face_id, frame_files, n_pairs):\n",
361362
" mismatches = [f for f in frame_files if face_id_from_filename(f) != face_id]\n",
362363
" return random.sample(mismatches, n_pairs)\n",
363364
"\n",
365+
"\n",
364366
"def create_pairs(face_files, frame_files, n_pairs_per_face):\n",
365367
" match_pairs = []\n",
366368
" mismatch_pairs = []\n",
@@ -403,16 +405,16 @@
403405
"print(f\"Number of faces: {len(faces)}\")\n",
404406
"print(f\"Number of frames: {len(frames)}\")\n",
405407
"\n",
406-
"n_train = int(len(faces) * 2/3)\n",
408+
"n_train = int(len(faces) * 2 / 3)\n",
407409
"n_val = len(faces) - n_train\n",
408410
"print(f\"Number of training faces: {n_train}\")\n",
409411
"print(f\"Number of validation faces: {n_val}\")\n",
410412
"\n",
411413
"train_faces = faces[:n_train]\n",
412414
"val_faces = faces[n_train:]\n",
413415
"\n",
414-
"train_frames = frames[:n_train*31]\n",
415-
"val_frames = frames[n_train*31:]\n",
416+
"train_frames = frames[: n_train * 31]\n",
417+
"val_frames = frames[n_train * 31 :]\n",
416418
"\n",
417419
"print(train_faces[-4:])\n",
418420
"print(val_faces[:4])\n",
@@ -1488,13 +1490,14 @@
14881490
],
14891491
"source": [
14901492
"def save_csv(pairs, filepath):\n",
1491-
" with open(filepath, mode='w', newline='') as file:\n",
1493+
" with open(filepath, mode=\"w\", newline=\"\") as file:\n",
14921494
" writer = csv.writer(file)\n",
14931495
" writer.writerow([\"face\", \"frame\"])\n",
14941496
" for face, frame in pairs:\n",
14951497
" writer.writerow([face, frame])\n",
14961498
" print(f\"Saved {len(pairs)} pairs to {filepath}\")\n",
14971499
"\n",
1500+
"\n",
14981501
"save_csv(match_train_pairs, \"data/rococo2v3-dev/train_match_pairs.csv\")\n",
14991502
"save_csv(mismatch_train_pairs, \"data/rococo2v3-dev/train_mismatch_pairs.csv\")\n",
15001503
"save_csv(match_val_pairs, \"data/rococo2v3-dev/val_match_pairs.csv\")\n",
@@ -1529,14 +1532,23 @@
15291532
}
15301533
],
15311534
"source": [
1532-
"used_frames = set([\n",
1533-
" *(frame for _, frame in match_train_pairs),\n",
1534-
" *(frame for _, frame in mismatch_train_pairs),\n",
1535-
" *(frame for _, frame in match_val_pairs),\n",
1536-
" *(frame for _, frame in mismatch_val_pairs),\n",
1537-
"])\n",
1535+
"used_frames = set(\n",
1536+
" [\n",
1537+
" *(frame for _, frame in match_train_pairs),\n",
1538+
" *(frame for _, frame in mismatch_train_pairs),\n",
1539+
" *(frame for _, frame in match_val_pairs),\n",
1540+
" *(frame for _, frame in mismatch_val_pairs),\n",
1541+
" ]\n",
1542+
")\n",
15381543
"\n",
1539-
"len(used_frames), sum((len(match_train_pairs), len(mismatch_train_pairs), len(match_val_pairs), len(mismatch_val_pairs)))"
1544+
"len(used_frames), sum(\n",
1545+
" (\n",
1546+
" len(match_train_pairs),\n",
1547+
" len(mismatch_train_pairs),\n",
1548+
" len(match_val_pairs),\n",
1549+
" len(mismatch_val_pairs),\n",
1550+
" )\n",
1551+
")"
15401552
]
15411553
},
15421554
{

0 commit comments

Comments
 (0)