forked from danlou/LMMS
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathexp_mapping.py
More file actions
executable file
·133 lines (94 loc) · 3.95 KB
/
exp_mapping.py
File metadata and controls
executable file
·133 lines (94 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Standard-library imports.
import os
import argparse
import logging
from time import time
from functools import lru_cache
from datetime import datetime
# Third-party imports.
import numpy as np
from nltk.corpus import wordnet as wn
import spacy
# Load the small English pipeline once at import time; used by get_sent_info().
nlp = spacy.load('en_core_web_sm')
# Project-local imports: BERT embedding service and the sense-vector space.
from bert_as_service import bert_embed
from vectorspace import SensesVSM
# Global logging configuration for the demo script.
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%d-%b-%y %H:%M:%S')
@lru_cache()
def wn_sensekey2synset(sensekey):
    """Map a WordNet sensekey (e.g. 'dog%1:05:00::') to its Synset.

    Scans every synset of the sensekey's surface lemma and compares each
    lemma's key against the given sensekey. Returns None when no match is
    found. Memoized with lru_cache since sensekeys repeat across queries.
    """
    # The substring before '%' is the surface lemma of the sensekey.
    surface_lemma = sensekey.split('%')[0]
    for synset in wn.synsets(surface_lemma):
        # Renamed from 'lemma' so the loop variable no longer shadows
        # the lemma string derived from the sensekey above.
        for synset_lemma in synset.lemmas():
            if synset_lemma.key() == sensekey:
                return synset
    return None
def get_sent_info(merge_ents=False):
    """Prompt the user for a sentence and annotate it with spaCy.

    Parameters
    ----------
    merge_ents : bool
        When True, multi-word named entities are merged into single tokens.

    Returns
    -------
    dict with keys 'sentence' (raw input), 'tokens', 'lemmas', 'pos',
    and 'tokenized_sentence' (tokens joined by single spaces).
    """
    sent_info = {'tokens': [], 'lemmas': [], 'pos': [], 'sentence': ''}
    sent_info['sentence'] = input('Input Sentence (\'q\' to exit):\n')
    doc = nlp(sent_info['sentence'])
    if merge_ents:
        # Span.merge() was deprecated and later removed from spaCy;
        # the retokenizer context manager is the supported replacement.
        with doc.retokenize() as retokenizer:
            for ent in doc.ents:
                retokenizer.merge(ent)
    for tok in doc:
        # Merged entity tokens may contain spaces, which would break
        # whitespace tokenization downstream — replace with underscores.
        sent_info['tokens'].append(tok.text.replace(' ', '_'))
        sent_info['lemmas'].append(tok.lemma_)
        sent_info['pos'].append(tok.pos_)
    sent_info['tokenized_sentence'] = ' '.join(sent_info['tokens'])
    return sent_info
def map_senses(svsm, tokens, postags=None, lemmas=None, use_postag=False, use_lemma=False):
    """Given loaded LMMS and a list of tokens, returns a list of scored sensekeys.

    Parameters
    ----------
    svsm : SensesVSM
        Loaded sense-vector space (1024-d or 2048-d sense vectors).
    tokens : list of str
        Token strings; joined with spaces and sent to the BERT service.
    postags, lemmas : list of str or None
        Per-token annotations. Only used when their length matches
        ``tokens`` AND the corresponding ``use_*`` flag is True.
    use_postag, use_lemma : bool
        Restrict candidate senses by POS tag / lemma.

    Returns
    -------
    list of lists — one inner list of (sensekey, score) pairs per token.
    """
    # Avoid mutable default arguments; treat missing annotations as empty.
    postags = [] if postags is None else postags
    lemmas = [] if lemmas is None else lemmas

    # Silently disable filters whose annotations do not align with tokens.
    if len(tokens) != len(postags):
        use_postag = False
    if len(tokens) != len(lemmas):
        use_lemma = False

    matches = []
    sent_bert = bert_embed([' '.join(tokens)], merge_strategy='mean')[0]
    for idx in range(len(tokens)):
        # Each entry pairs a token string with its embedding; take the vector.
        idx_vec = sent_bert[idx][1]
        idx_vec = idx_vec / np.linalg.norm(idx_vec)

        if svsm.ndims == 1024 + 1024:
            # 2048-d sense vectors are a concatenation of two spaces;
            # duplicate the context embedding and renormalize to match.
            idx_vec = np.hstack((idx_vec, idx_vec))
            idx_vec = idx_vec / np.linalg.norm(idx_vec)
        # For 1024-d spaces (and any other dimensionality) the vector is
        # used as-is, matching the original no-op branch.

        # Pass None for any disabled filter instead of branching four ways.
        lemma_filter = lemmas[idx] if use_lemma else None
        postag_filter = postags[idx] if use_postag else None
        matches.append(svsm.match_senses(idx_vec, lemma_filter, postag_filter, topn=None))
    return matches
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Concept Mapping Demonstration.')
    parser.add_argument('-sv_path', help='Path to sense vectors', required=True)
    args = parser.parse_args()

    logging.info('Loading SensesVSM ...')
    senses_vsm = SensesVSM(args.sv_path, normalize=True)

    # Interactive loop: read a sentence, score senses per token, print top-10.
    while True:
        sent_info = get_sent_info()
        if sent_info['sentence'] == 'q':
            break
        elif len(sent_info['sentence']) == 0:
            continue

        matches = map_senses(senses_vsm,
                             sent_info['tokens'],
                             sent_info['pos'],
                             sent_info['lemmas'],
                             use_lemma=False,
                             use_postag=False)

        for idx, idx_matches in enumerate(matches):
            print()
            print('TOK: %s | POS: %s | LEM: %s' % (sent_info['tokens'][idx],
                                                   sent_info['pos'][idx],
                                                   sent_info['lemmas'][idx]))
            print('Top 10 Matches (out of %d):' % len(idx_matches))
            for sk_idx, (sk, score) in enumerate(idx_matches[:10]):
                synset = wn_sensekey2synset(sk)
                print('#%d - %.3f %s %s' % (sk_idx + 1, score, sk, synset))
                # wn_sensekey2synset returns None for unmappable keys;
                # guard so .definition() doesn't raise AttributeError.
                if synset is not None:
                    print('DEF: %s' % synset.definition())
            print()