-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_context.py
More file actions
70 lines (52 loc) · 2.16 KB
/
run_context.py
File metadata and controls
70 lines (52 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import json
import pandas as pd
from context import Context
from nn import *
import argparse
from collections import Counter
def initialize_model(train_df, params):
    """Preprocess training data, build the vocabulary, and train a ``Context`` model.

    Args:
        train_df: Training DataFrame. After ``preprocess`` it must provide
            ``"text"``, ``"hate"`` and ``"offensive"`` columns (all read below).
        params: Parameter dict. Reads ``"vocab_size"`` and ``"SGT_path"``.
            It is mutated in place: ``params["num_SGT"]`` is set to the SGT
            count returned by ``tokens_to_ids`` before the model is built.

    Returns:
        Tuple ``(model, vocab, SGT_dict)``: the trained ``Context`` model, the
        learned vocabulary, and the SGT mapping returned by ``tokens_to_ids``
        (``test_model`` hands this mapping back to ``tokens_to_ids`` for the
        test set).
    """
    train_df = preprocess(train_df)
    train_text = train_df["text"].values.tolist()
    vocab = learn_vocab(train_text, params["vocab_size"])

    train_tokens, SGT, count, SGT_dict = tokens_to_ids(
        train_text, vocab, params["SGT_path"])
    # Recorded before Context(params, ...) is constructed; presumably the
    # model sizes something from "num_SGT" — TODO confirm in context.py.
    params["num_SGT"] = count
    print(count, "unique SGTs")

    # Extract each label column once; they feed both the weights and batching.
    hate_labels = train_df["hate"].tolist()
    off_labels = train_df["offensive"].tolist()

    # Weights computed by cal_weights for the SGT ids and each label column
    # (presumably class-balancing weights — verify against nn.cal_weights).
    sgt_W = cal_weights(SGT)
    hate_W = cal_weights(hate_labels)
    off_W = cal_weights(off_labels)

    model = Context(params, vocab, sgt_W, hate_W, off_W)
    batches = get_batches(train_tokens,
                          model.batch_size,
                          vocab.index("<pad>"),
                          hate_labels,
                          off_labels,
                          SGT=SGT)
    model.train(batches)
    return model, vocab, SGT_dict
def test_model(test_df, vocab, model, SGT_dict, params=None):
    """Run ``model`` on ``test_df`` and report hate/offensive prediction results.

    Args:
        test_df: Test DataFrame; after ``preprocess`` it must provide a
            ``"text"`` column. It is also passed to ``prediction_results``.
        vocab: Vocabulary from ``initialize_model``; ``vocab.index("<pad>")``
            supplies the padding id for batching.
        model: Trained model exposing ``predict_hate(batches, labels)``.
        SGT_dict: SGT mapping returned by ``initialize_model``; passed to
            ``tokens_to_ids`` (presumably so test tokens reuse the training
            mapping — verify against nn.tokens_to_ids).
        params: Parameter dict providing ``"SGT_path"`` and ``"batch_size"``.
            Optional for backward compatibility: when omitted, the
            module-level ``params`` (bound by the ``__main__`` block) is used,
            which is what the original implementation relied on.

    Returns:
        The predictions returned by ``model.predict_hate``.

    Raises:
        ValueError: if ``params`` is not given and no module-level ``params``
            exists (e.g. when this module is imported rather than run).
    """
    if params is None:
        # Previously this function read the global implicitly; keep that
        # fallback but fail with a clear message instead of a NameError.
        params = globals().get("params")
        if params is None:
            raise ValueError(
                "test_model needs a params dict with 'SGT_path' and "
                "'batch_size'; pass it as the `params` argument")

    test_df = preprocess(test_df)
    test_text = test_df["text"].values.tolist()
    # Only the token ids and SGT sequence are needed; count/dict are unused here.
    test_tokens, SGT, _, _ = tokens_to_ids(
        test_text, vocab, params["SGT_path"], SGT_dict)
    batches = get_batches(test_tokens,
                          params["batch_size"],
                          vocab.index("<pad>"),
                          SGT=SGT)
    test_predictions = model.predict_hate(batches, ["hate", "offensive"])
    prediction_results(test_df, test_predictions, ["hate", "offensive"])
    return test_predictions
if __name__ == '__main__':
    import sys

    parser = argparse.ArgumentParser(
        description="Train a Context model on --data and evaluate it on --test.")
    # All three paths are mandatory; without `required=True` a missing flag
    # yields None and crashes later with an unrelated pandas/TypeError.
    parser.add_argument("--data", required=True,
                        help="Path to data; includes text, hate and offensive columns")
    parser.add_argument("--params", required=True,
                        help="Parameter files. should be a json file")
    parser.add_argument("--test", required=True,
                        help="Path to test data (CSV with a text column)")
    args = parser.parse_args()

    # Validate the params file before reading the (possibly large) CSVs, and
    # close the handle deterministically. json.JSONDecodeError and
    # UnicodeDecodeError are both ValueError subclasses.
    try:
        with open(args.params, 'r') as params_file:
            params = json.load(params_file)
    except (OSError, ValueError) as err:
        # sys.exit(str) prints to stderr and exits with status 1.
        sys.exit(f"Wrong params file: {err}")

    data = pd.read_csv(args.data)
    test = pd.read_csv(args.test)

    # NOTE: `params` must stay bound at module scope here — test_model may
    # read it as a global when it is not passed explicitly.
    model, vocab, sgt_dict = initialize_model(data, params)
    test_model(test, vocab, model, sgt_dict)