-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathget_data.py
More file actions
34 lines (23 loc) · 871 Bytes
/
get_data.py
File metadata and controls
34 lines (23 loc) · 871 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
import torch
from pathlib import Path
# Category names. Each one is used as a subdirectory name under the data root
# in read_data, and a sample's integer label is the position of its category
# in this list -- so the order must not change once labels have been produced.
label_tags = ["course", "department", "faculty", "other", "project", "staff", "student"]
def read_data(dir_path):
    """Load every HTML file found under the per-label subdirectories.

    For each name in ``label_tags`` the directory ``dir_path/<name>`` is
    searched recursively for ``*.html`` files.

    Args:
        dir_path: Root directory (``str`` or path-like) that holds one
            subdirectory per entry of ``label_tags``.

    Returns:
        A ``(texts, labels)`` tuple of two equal-length lists. ``texts[i]``
        is the content of one file and ``labels[i]`` is the index in
        ``label_tags`` of the subdirectory that file was found under.
        Samples are grouped by label in ``label_tags`` order and sorted by
        path within each label.
    """
    dir_path = Path(dir_path)
    texts = []
    labels = []
    # enumerate gives the label id directly, avoiding a linear
    # label_tags.index() lookup for every single file.
    for label_id, label_name in enumerate(label_tags):
        # glob yields entries in file-system-dependent order; sorting makes
        # the returned sample order reproducible across runs and machines.
        for html_file in sorted((dir_path / label_name).glob('**/*.html')):
            # NOTE(review): read_text() decodes with the platform default
            # encoding -- confirm this matches the corpus, otherwise pass an
            # explicit encoding here.
            texts.append(html_file.read_text())
            labels.append(label_id)
    return texts, labels
class web_dataset(torch.utils.data.Dataset):
    """Dataset that pairs per-sample encodings with their labels.

    ``encodings`` is a mapping from field name to an indexable collection
    holding one entry per sample; ``labels`` is an indexable, sized
    collection holding one label per sample.
    (Presumably ``encodings`` is tokenizer output -- verify against caller.)
    """

    def __init__(self, encodings, labels):
        """Store the encodings mapping and the label collection as given."""
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        """Return the number of samples, taken from the label collection."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors.

        The dict has one tensor per key of ``encodings`` plus a ``'labels'``
        entry holding the sample's label as a tensor.
        """
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        # Assigned last, so it overrides any 'labels' field from encodings.
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample