From b6cafdac217ccb25b0aae7f0c3eaecd45f8a567d Mon Sep 17 00:00:00 2001 From: Eric Fosler-Lussier Date: Thu, 28 Jan 2021 19:37:14 -0500 Subject: [PATCH 1/7] added SymbolTable --- mfst/__init__.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/mfst/__init__.py b/mfst/__init__.py index 8079348..d1b4d4a 100644 --- a/mfst/__init__.py +++ b/mfst/__init__.py @@ -2,7 +2,9 @@ # # Wrapper for OpenFst that supports defining custom semirings in python # and drawing FSTs in ipython notebooks - +# +# Edit 2021 Eric Fosler-Lussier to create SymbolTable class, allow for different input/output symbol tables, +# allow compilation import openfst_wrapper_backend as _backend from collections import namedtuple as _namedtuple, deque as _deque @@ -158,7 +160,7 @@ class FST(object): Wraps a mutable FST class """ - def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, *, _fst=None): + def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, output_string_mapper=None, *, _fst=None): if semiring_class is None: semiring_class = PythonValueSemiringWeight elif type(semiring_class) is not type: @@ -201,6 +203,10 @@ def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, *, _ # and passing that through ord() and chr() to print them in the graphics that we are drawing. So if that is being # used, then this will get set to true inside of add_arc self._string_mapper = string_mapper + # add output string mapper, which can also be the same alphabet + self._output_string_mapper = output_string_mapper + if self._output_string_mapper is None: + self._output_string_mapper = string_mapper def _make_weight(self, w): if isinstance(w, self._semiring_class): @@ -1025,3 +1031,44 @@ def make_label(x): ''') return ''.join(ret2) + +class SymbolTable(): + """ + Creates a symbol table that maps strings to ids; can function as a callable + to pass as input to FST constructor. + """ + def __init__(self,mutableOnFly=False,strict=False): + self.__sym2id = {} + self.__id2sym = {} + self.__mutableOnFly = mutableOnFly + self.__strict = strict + + def add_symbol(self,sym): + if sym in sym2id: + if strict: + raise Exception('Symbol '+sym+' already in symbol table') + return self.__sym2id[sym] + newid=len(self.__sym2id) + self.__sym2id[sym]=newid + self.__id2sym[newid]=sym + return newid + + def get_symbol(self,sym): + if sym in sym2id: + return sym2id[sym] + elif self.__mutableOnFly: + return self.add_symbol(sym) + else: + raise Exception('Symbol '+sym+' not in symbol table') + + def __get_item__(self, sym): + return get_symbol(self,sym) + + def __call__(self,symid): + if symid in self.__id2sym: + return self.__id2sym[symid] + else: + return symid + + + \ No newline at end of file From d67fed05aaf8167aa8275714a7499cd7d19fd200 Mon Sep 17 00:00:00 2001 From: Eric Fosler-Lussier Date: Thu, 28 Jan 2021 21:37:26 -0500 Subject: [PATCH 2/7] added compiler, output symbols --- mfst/__init__.py | 115 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 100 insertions(+), 15 deletions(-) diff --git a/mfst/__init__.py b/mfst/__init__.py index d1b4d4a..f72cb00 100644 --- a/mfst/__init__.py +++ b/mfst/__init__.py @@ -172,6 +172,7 @@ def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, outp semiring_class = BooleanSemiringWeight acceptor = True string_mapper = f._string_mapper + output_string_mapper = f._output_string_mapper else: assert issubclass(semiring_class, AbstractSemiringWeight), "first argument is not iterable or a semiring class" @@ -238,7 +239,8 @@ def constructor(self, _fst=None, **kwargs): _fst=_fst, semiring_class=self._semiring_class, acceptor=self._acceptor, - string_mapper=self._string_mapper + string_mapper=self._string_mapper, + output_string_mapper=self._output_string_mapper ) params.update(kwargs) return type(self)(**params) @@ -288,18 +290,24 @@ def get_unique_output_string(self): else: mapper = lambda x: x + if self._output_string_mapper is not None: + omapper = self._output_string_mapper + else: + omapper = lambda x: x + + while state != -1: edges = list(self.get_arcs(state)) if len(edges) != 1: raise RuntimeError("FST does not contain exactly one path") l = edges[0].output_label if l != 0: # the epsilon state - ret.append(mapper(l)) + ret.append(omapper(l)) if edges[0].nextstate in seen: raise RuntimeError("FST contains cycle") seen.add(state) state = edges[0].nextstate - if mapper is chr: + if omapper is chr: return ''.join(ret) return ret @@ -364,8 +372,8 @@ def add_arc(self, from_state, to_state, if isinstance(output_label, str): assert len(output_label) == 1, "FST string labels can only be a single character" output_label = ord(output_label) - if self._string_mapper is None: - self._string_mapper = chr + if self._output_string_mapper is None: + self._output_string_mapper = chr if self._acceptor: # acceptors are machines with the same input and output label if output_label == 0: # if not set just copy the value @@ -706,7 +714,7 @@ def lift(self, semiring=None, converter=None): else: converter = lambda x: x - ret = FST(semiring, acceptor=self._acceptor, string_mapper=self._string_mapper) + ret = FST(semiring, acceptor=self._acceptor, string_mapper=self._string_mapper, output_string_mapper=self._output_string_mapper) zero = self.semiring_zero for i in range(self.num_states): ret.add_state() # would be nice if this did not need to be called in a loop @@ -826,6 +834,7 @@ def __getstate__(self): 'semiring_class': self._semiring_class, 'acceptor': self._acceptor, 'string_mapper': self._string_mapper, + 'output_string_mapper': self._output_string_mapper, 'num_states': self.num_states, 'arcs': [[tuple(x) for x in self.get_arcs(s)] for s in self.states], 'initial_state': self.initial_state, @@ -836,11 +845,13 @@ def __setstate__(self, d): semiring_class=d['semiring_class'], acceptor=d['acceptor'], string_mapper=d['string_mapper'], + output_string_mapper=d['output_string_mapper'] ) self._fst = f._fst self._semiring_class = f._semiring_class self._acceptor = f._acceptor self._string_mapper = f._string_mapper + self._output_string_mapper = f._output_string_mapper for i in range(d['num_states']): self.add_state() self.initial_state = d['initial_state'] @@ -904,6 +915,20 @@ def make_label(x): else: make_label = str + if self._output_string_mapper is not None: + if self._output_string_mapper is chr: + def make_olabel(x): + if x == 32: + return '(spc)' + elif x < 32: + return str(x) + else: + return chr(x) + else: + make_olabel = self._output_string_mapper + else: + make_olabel = str + for sid in range(self.num_states): to = defaultdict(list) for arc in self.get_arcs(sid): @@ -920,7 +945,7 @@ def make_label(x): if arc.output_label == 0: label += '\u03B5' else: - label += make_label(arc.output_label) + label += make_olabel(arc.output_label) if one != arc.weight: label += f'/{arc.weight}' to[arc.nextstate].append(label) @@ -1032,7 +1057,7 @@ def make_label(x): ''') return ''.join(ret2) -class SymbolTable(): +class SymbolTable(object): """ Creates a symbol table that maps strings to ids; can function as a callable to pass as input to FST constructor. @@ -1044,7 +1069,7 @@ def __init__(self,mutableOnFly=False,strict=False): self.__strict = strict def add_symbol(self,sym): - if sym in sym2id: + if sym in self.__sym2id: if strict: raise Exception('Symbol '+sym+' already in symbol table') return self.__sym2id[sym] @@ -1054,21 +1079,81 @@ def add_symbol(self,sym): return newid def get_symbol(self,sym): - if sym in sym2id: - return sym2id[sym] + if sym in self.__sym2id: + return self.__sym2id[sym] elif self.__mutableOnFly: return self.add_symbol(sym) else: raise Exception('Symbol '+sym+' not in symbol table') - def __get_item__(self, sym): - return get_symbol(self,sym) + def __getitem__(self, sym): + return self.get_symbol(sym) def __call__(self,symid): if symid in self.__id2sym: return self.__id2sym[symid] else: return symid - - \ No newline at end of file +import re + +def compiler(strings,isymbols=None,osymbols=None,acceptor=False,add_symbols=True): + + # create isymbols, osymbols + if add_symbols and isymbols is None: + isymbols=SymbolTable(mutableOnFly=add_symbols) + isymbols.add_symbol("-") + if add_symbols and not acceptor and osymbols is None: + osymbols=SymbolTable(mutableOnFly=add_symbols) + osymbols.add_symbol("-") + + if (acceptor): + f=FST(acceptor=acceptor,string_mapper=isymbols) + else: + f=FST(acceptor=acceptor, string_mapper=isymbols, output_string_mapper=osymbols) + + # always have 0 be the initial state per OpenFST compilation standards + states={0: f.add_state()} + f.set_initial_state(states[0]) + + for s in strings: + parts=re.split('\s+',s) + weight=1 + max=5 + if acceptor: + max=4 + + parts[0]=int(parts[0]) + if len(parts)>2: + parts[1]=int(parts[1]) + if len(parts)>max: + raise Exception('Syntax error: '+s) + if len(parts)==max: + weight=float(parts[max-1]) + if len(parts)>=max-1: + if parts[0] not in states: + states[parts[0]]=f.add_state() + s0=states[parts[0]] + + if parts[1] not in states: + states[parts[1]]=f.add_state() + s1=states[parts[1]] + isym=isymbols[parts[2]] + if acceptor: + f.add_arc(s0,s1,weight,isym) + else: + osym=osymbols[parts[3]] + f.add_arc(s0,s1,weight,isym,osym) + elif len(parts)==2: + if parts[0] not in states: + states[parts[0]]=f.add_state() + s0=states[parts[0]] + f.set_final_weight(s0,parts[1]) + elif len(parts)==1: + if parts[0] not in states: + states[parts[0]]=f.add_state() + s0=states[parts[0]] + f.set_final_weight(s0) + else: + raise Exception('Syntax error: '+s) + return f From 8f1d8710f3e3996dd98a8961c634ea2f725ea7f9 Mon Sep 17 00:00:00 2001 From: Eric Fosler-Lussier Date: Thu, 28 Jan 2021 21:55:00 -0500 Subject: [PATCH 3/7] allow for diff symbol tables --- mfst/__init__.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mfst/__init__.py b/mfst/__init__.py index f72cb00..1127c4b 100644 --- a/mfst/__init__.py +++ b/mfst/__init__.py @@ -940,12 +940,14 @@ def make_olabel(x): label += '\u03B5' # epsilon else: label += make_label(arc.input_label) - if arc.input_label != arc.output_label: - label += ':' - if arc.output_label == 0: - label += '\u03B5' - else: - label += make_olabel(arc.output_label) + if arc.output_label == 0: + olabel = '\u03B5' # epsilon + else: + olabel = make_olabel(arc.output_label) + + if label != olabel: + label += ':'+olabel + if one != arc.weight: label += f'/{arc.weight}' to[arc.nextstate].append(label) From 5578f185a3764a48384f950ba6f52fc3a59a62ce Mon Sep 17 00:00:00 2001 From: Eric Fosler-Lussier Date: Thu, 28 Jan 2021 22:13:11 -0500 Subject: [PATCH 4/7] allow for diff symbol tables --- mfst/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mfst/__init__.py b/mfst/__init__.py index 1127c4b..c0b4359 100644 --- a/mfst/__init__.py +++ b/mfst/__init__.py @@ -927,7 +927,7 @@ def make_olabel(x): else: make_olabel = self._output_string_mapper else: - make_olabel = str + make_olabel = make_label for sid in range(self.num_states): to = defaultdict(list) @@ -935,7 +935,7 @@ def make_olabel(x): if arc.nextstate == -1: continue - label = '' + label = 'L:' if arc.input_label == 0: label += '\u03B5' # epsilon else: From b6f0b5e56bea77f1e1ed516e6da4cb4c0323a107 Mon Sep 17 00:00:00 2001 From: Eric Fosler-Lussier Date: Thu, 28 Jan 2021 22:27:25 -0500 Subject: [PATCH 5/7] allow for diff symbol tables --- mfst/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mfst/__init__.py b/mfst/__init__.py index c0b4359..e24cef2 100644 --- a/mfst/__init__.py +++ b/mfst/__init__.py @@ -935,7 +935,7 @@ def make_olabel(x): if arc.nextstate == -1: continue - label = 'L:' + label = '' if arc.input_label == 0: label += '\u03B5' # epsilon else: @@ -945,6 +945,7 @@ def make_olabel(x): else: olabel = make_olabel(arc.output_label) + print(label+' '+olabel) if label != olabel: label += ':'+olabel From d496d4ba470107bdc4dec22e775d72aa5104a0f3 Mon Sep 17 00:00:00 2001 From: Eric Fosler-Lussier Date: Thu, 28 Jan 2021 22:36:35 -0500 Subject: [PATCH 6/7] allow for diff symbol tables --- mfst/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mfst/__init__.py b/mfst/__init__.py index e24cef2..3cac5a8 100644 --- a/mfst/__init__.py +++ b/mfst/__init__.py @@ -945,7 +945,6 @@ def make_olabel(x): else: olabel = make_olabel(arc.output_label) - print(label+' '+olabel) if label != olabel: label += ':'+olabel @@ -1107,8 +1106,7 @@ def compiler(strings,isymbols=None,osymbols=None,acceptor=False,add_symbols=True isymbols=SymbolTable(mutableOnFly=add_symbols) isymbols.add_symbol("-") if add_symbols and not acceptor and osymbols is None: - osymbols=SymbolTable(mutableOnFly=add_symbols) - osymbols.add_symbol("-") + osymbols=isymbols # this is to handle bug in printing if (acceptor): f=FST(acceptor=acceptor,string_mapper=isymbols) From c882f72bbc7e75f46e4c444023ce44b3345f96bd Mon Sep 17 00:00:00 2001 From: Eric Fosler-Lussier Date: Thu, 28 Jan 2021 23:06:07 -0500 Subject: [PATCH 7/7] allow for diff symbol tables --- mfst/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mfst/__init__.py b/mfst/__init__.py index 3cac5a8..45e5eee 100644 --- a/mfst/__init__.py +++ b/mfst/__init__.py @@ -1106,8 +1106,9 @@ def compiler(strings,isymbols=None,osymbols=None,acceptor=False,add_symbols=True isymbols=SymbolTable(mutableOnFly=add_symbols) isymbols.add_symbol("-") if add_symbols and not acceptor and osymbols is None: - osymbols=isymbols # this is to handle bug in printing - + osymbols=SymbolTable(mutableOnFly=add_symbols) # this is to handle bug in printing + osymbols.add_symbol("-") + if (acceptor): f=FST(acceptor=acceptor,string_mapper=isymbols) else: