-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPOS_accuracy.py
More file actions
118 lines (102 loc) · 4.93 KB
/
POS_accuracy.py
File metadata and controls
118 lines (102 loc) · 4.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
### import the taggers to evaluate
from turtle import pos
from nltk import pos_tag
import string
from sklearn import metrics
from flair.models import SequenceTagger
from flair.data import Sentence
# Load the fine-tuned FLAIR POS tagger once at module import time.
# Used by tag_sentence() below; the .pt checkpoint path is machine-specific.
tagger_FLAIR = SequenceTagger.load("C:/Users/gusta/Documents/GitHub/Reddit_MDA/RedditTaggerFinal150.pt")
# Emoticons to strip from cleaned sentences. Order matters: longer variants
# such as ":-)" must precede their shorter suffixes (":)"), otherwise a
# stray ":-" would be left behind. The list is lowercase-only because
# clean_sentence lowercases its input BEFORE removal (the original code also
# listed uppercase variants like ":-P"/":-O", which could never match after
# .lower(); ":-o" is included here so the ":-O" emoticon is actually removed).
_EMOTICONS = [":-)", ":)", ";-)", ":-p", ";-p", ":-(", ";-(", ":-o", "^^", "-.-", ":-$", ":-\\", ":-/", ":-|", ";-/", ";-\\",
              ":-[", ":-]", ":-§", "owo", "*.*", ";)", ":p", ";p", ":(", ";(", ":o", ":|", ";/", ";\\", ":[", ":]", ":§"]

def clean_sentence(sentence):
    '''Lowercase *sentence*, strip leading/trailing punctuation, and remove
    known ASCII emoticons.

    Note: punctuation is only stripped at the string's edges (str.strip),
    not removed throughout, and Unicode emojis are NOT removed here.
    Emoticons and emojis are already counted (but not removed) in the
    analyse_sentence function; links and URLs are counted AND removed there.

    Parameters:
        sentence: any value; coerced to str before cleaning.

    Returns:
        str: the cleaned sentence.
    '''
    sentence = str(sentence).strip(string.punctuation).lower()
    for emoticon in _EMOTICONS:
        sentence = sentence.replace(emoticon, "")
    return sentence
def tag_sentence(sentence):
    '''Clean *sentence* with clean_sentence() and POS-tag it with the
    module-level FLAIR tagger (tagger_FLAIR).

    Quote-token tags ("''" and "``") are dropped from the output.
    A look-ahead/look-behind buffer of three ("X", "X") items is added on
    each side so callers can index neighbouring tokens without negative
    indices or IndexErrors.

    Returns:
        list[tuple[str, str]]: (word, pos_tag) pairs, padded as described.
    '''
    cleaned_sentence = clean_sentence(sentence)
    flair_sentence = Sentence(cleaned_sentence)
    tagger_FLAIR.predict(flair_sentence)
    token_list = []
    for label in flair_sentence.get_labels('pos'):
        # Drop FLAIR's opening/closing quote tags ("``" / "''").
        if label.value not in ("''", "``"):
            token_list.append((label.data_point.text, label.value))
    # Three-item padding on both ends prevents IndexError / negative-index
    # wraparound in look-ahead/look-behind consumers.
    empty_look = [("X", "X")] * 3
    return empty_look + token_list + empty_look
def sentence_tags(w_pos_list):
    '''Tag one gold-annotated sentence with both automatic taggers.

    *w_pos_list* is a list of (word, gold_tag) pairs. The words are re-joined
    into a sentence, tagged with NLTK's pos_tag and with the FLAIR tagger
    (via tag_sentence), and the three parallel tag lists are returned as
    (gold, nltk, flair). If the taggers produce a different number of tokens
    than the gold standard, the mismatch is printed and three empty lists
    are returned instead.
    '''
    reference = [pair[1] for pair in w_pos_list]
    tokens = [pair[0] for pair in w_pos_list]
    nltk_tagged = [tag for _, tag in pos_tag(tokens)]
    # Re-join the tokens so FLAIR re-tokenises the same sentence, then drop
    # the three ("X", "X") padding items tag_sentence adds on each side.
    flair_tagged = [tag for _, tag in tag_sentence(" ".join(tokens))][3:-3]
    if len(reference) == len(nltk_tagged) == len(flair_tagged):
        return (reference, nltk_tagged, flair_tagged)
    print("Unequal token numbers")
    print(reference)
    print(flair_tagged)
    return ([], [], [])
### Accumulate gold / NLTK / FLAIR tags over every manually tagged sentence,
### then write per-tag classification reports for both automatic taggers.
gold = []
nltk_tags = []
flair_tags = []

# NOTE(review): the original script opened Batch1 THREE times and Batch2
# once; that repetition is preserved here so results are unchanged, but it
# triple-counts Batch1 in the accuracy report -- confirm whether intended.
input_paths = [
    "C:/Users/gusta/Documents/GitHub/Reddit_MDA/Tagged_JSONS/RC_2005-12_tagged_manual_Batch1_done.txt",
    "C:/Users/gusta/Documents/GitHub/Reddit_MDA/Tagged_JSONS/RC_2005-12_tagged_manual_Batch2_done.txt",
    "C:/Users/gusta/Documents/GitHub/Reddit_MDA/Tagged_JSONS/RC_2005-12_tagged_manual_Batch1_done.txt",
    "C:/Users/gusta/Documents/GitHub/Reddit_MDA/Tagged_JSONS/RC_2005-12_tagged_manual_Batch1_done.txt",
]

for path in input_paths:
    with open(path) as tagged_file:
        for line in tagged_file:
            fields = line.split("\t")
            # Skip lines whose third field is an empty annotation list.
            if fields[2] == "[]\n":
                continue
            print(fields[0])  # progress: first field identifies the line
            sent_raw = fields[2].strip("\n")
            # Parse the "[[word, TAG], [word, TAG], ...]" string into pairs.
            sent_split = sent_raw.strip("[]").replace("'", "").split("], [")
            sentence = [x.split(", ") for x in sent_split]
            # Tag each line ONCE. The original called sentence_tags(sentence)
            # three times per line, re-running FLAIR inference needlessly for
            # identical results.
            line_gold, line_nltk, line_flair = sentence_tags(sentence)
            gold += line_gold
            nltk_tags += line_nltk
            flair_tags += line_flair

# Per-tag precision/recall/F1 for each automatic tagger vs. the gold standard.
nltk_classification = metrics.classification_report(gold, nltk_tags)
flair_classification = metrics.classification_report(gold, flair_tags)
with open("Tagging_accuracy_reports.txt", "w") as report_file:
    report_file.write(nltk_classification)
    report_file.write("\n\n\n")
    report_file.write(flair_classification)