diff --git a/mfst/__init__.py b/mfst/__init__.py index 8079348..45e5eee 100644 --- a/mfst/__init__.py +++ b/mfst/__init__.py @@ -2,7 +2,9 @@ # # Wrapper for OpenFst that supports defining custom semirings in python # and drawing FSTs in ipython notebooks - +# +# Edit 2021 Eric Fosler-Lussier to create SymbolTable class, allow for different input/output symbol tables, +# allow compilation import openfst_wrapper_backend as _backend from collections import namedtuple as _namedtuple, deque as _deque @@ -158,7 +160,7 @@ class FST(object): Wraps a mutable FST class """ - def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, *, _fst=None): + def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, output_string_mapper=None, *, _fst=None): if semiring_class is None: semiring_class = PythonValueSemiringWeight elif type(semiring_class) is not type: @@ -170,6 +172,7 @@ def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, *, _ semiring_class = BooleanSemiringWeight acceptor = True string_mapper = f._string_mapper + output_string_mapper = f._output_string_mapper else: assert issubclass(semiring_class, AbstractSemiringWeight), "first argument is not iterable or a semiring class" @@ -201,6 +204,10 @@ def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, *, _ # and passing that through ord() and chr() to print them in the graphics that we are drawing. 
So if that is being # used, then this will get set to true inside of add_arc self._string_mapper = string_mapper + # add output string mapper, which can also be the same alphabet + self._output_string_mapper = output_string_mapper + if self._output_string_mapper is None: + self._output_string_mapper = string_mapper def _make_weight(self, w): if isinstance(w, self._semiring_class): @@ -232,7 +239,8 @@ def constructor(self, _fst=None, **kwargs): _fst=_fst, semiring_class=self._semiring_class, acceptor=self._acceptor, - string_mapper=self._string_mapper + string_mapper=self._string_mapper, + output_string_mapper=self._output_string_mapper ) params.update(kwargs) return type(self)(**params) @@ -282,18 +290,24 @@ def get_unique_output_string(self): else: mapper = lambda x: x + if self._output_string_mapper is not None: + omapper = self._output_string_mapper + else: + omapper = lambda x: x + + while state != -1: edges = list(self.get_arcs(state)) if len(edges) != 1: raise RuntimeError("FST does not contain exactly one path") l = edges[0].output_label if l != 0: # the epsilon state - ret.append(mapper(l)) + ret.append(omapper(l)) if edges[0].nextstate in seen: raise RuntimeError("FST contains cycle") seen.add(state) state = edges[0].nextstate - if mapper is chr: + if omapper is chr: return ''.join(ret) return ret @@ -358,8 +372,8 @@ def add_arc(self, from_state, to_state, if isinstance(output_label, str): assert len(output_label) == 1, "FST string labels can only be a single character" output_label = ord(output_label) - if self._string_mapper is None: - self._string_mapper = chr + if self._output_string_mapper is None: + self._output_string_mapper = chr if self._acceptor: # acceptors are machines with the same input and output label if output_label == 0: # if not set just copy the value @@ -700,7 +714,7 @@ def lift(self, semiring=None, converter=None): else: converter = lambda x: x - ret = FST(semiring, acceptor=self._acceptor, string_mapper=self._string_mapper) + ret = 
FST(semiring, acceptor=self._acceptor, string_mapper=self._string_mapper, output_string_mapper=self._output_string_mapper) zero = self.semiring_zero for i in range(self.num_states): ret.add_state() # would be nice if this did not need to be called in a loop @@ -820,6 +834,7 @@ def __getstate__(self): 'semiring_class': self._semiring_class, 'acceptor': self._acceptor, 'string_mapper': self._string_mapper, + 'output_string_mapper': self._output_string_mapper, 'num_states': self.num_states, 'arcs': [[tuple(x) for x in self.get_arcs(s)] for s in self.states], 'initial_state': self.initial_state, @@ -830,11 +845,13 @@ def __setstate__(self, d): semiring_class=d['semiring_class'], acceptor=d['acceptor'], string_mapper=d['string_mapper'], + output_string_mapper=d['output_string_mapper'] ) self._fst = f._fst self._semiring_class = f._semiring_class self._acceptor = f._acceptor self._string_mapper = f._string_mapper + self._output_string_mapper = f._output_string_mapper for i in range(d['num_states']): self.add_state() self.initial_state = d['initial_state'] @@ -898,6 +915,20 @@ def make_label(x): else: make_label = str + if self._output_string_mapper is not None: + if self._output_string_mapper is chr: + def make_olabel(x): + if x == 32: + return '(spc)' + elif x < 32: + return str(x) + else: + return chr(x) + else: + make_olabel = self._output_string_mapper + else: + make_olabel = make_label + for sid in range(self.num_states): to = defaultdict(list) for arc in self.get_arcs(sid): @@ -909,12 +940,14 @@ def make_label(x): label += '\u03B5' # epsilon else: label += make_label(arc.input_label) - if arc.input_label != arc.output_label: - label += ':' - if arc.output_label == 0: - label += '\u03B5' - else: - label += make_label(arc.output_label) + if arc.output_label == 0: + olabel = '\u03B5' # epsilon + else: + olabel = make_olabel(arc.output_label) + + if label != olabel: + label += ':'+olabel + if one != arc.weight: label += f'/{arc.weight}' to[arc.nextstate].append(label) 
class SymbolTable(object):
    """
    Two-way mapping between symbols and integer ids.

    Ids are handed out consecutively starting at 0, in the order symbols are
    first added.  An instance works in both directions:

    * ``table[sym]`` / ``table.get_symbol(sym)`` map a symbol to its id.
    * ``table(symid)`` maps an id back to its symbol, which lets the table be
      passed to the FST constructor as ``string_mapper`` /
      ``output_string_mapper``.

    :param mutableOnFly: when true, looking up an unknown symbol adds it
        instead of raising.
    :param strict: when true, adding a symbol that is already present raises
        instead of returning the existing id.
    """

    def __init__(self, mutableOnFly=False, strict=False):
        self.__sym2id = {}
        self.__id2sym = {}
        self.__mutableOnFly = mutableOnFly
        self.__strict = strict

    def add_symbol(self, sym):
        """
        Register ``sym`` and return its id.

        If the symbol is already present its existing id is returned, unless
        the table was created with ``strict=True``, in which case
        ``ValueError`` is raised.
        """
        if sym in self.__sym2id:
            # The flag lives on the instance; a bare `strict` is not a name
            # in this scope and would raise NameError.
            if self.__strict:
                raise ValueError(f'Symbol {sym} already in symbol table')
            return self.__sym2id[sym]
        newid = len(self.__sym2id)
        self.__sym2id[sym] = newid
        self.__id2sym[newid] = sym
        return newid

    def get_symbol(self, sym):
        """
        Return the id of ``sym``.

        Unknown symbols are added when the table is mutable-on-the-fly;
        otherwise ``KeyError`` is raised.
        """
        if sym in self.__sym2id:
            return self.__sym2id[sym]
        if self.__mutableOnFly:
            return self.add_symbol(sym)
        raise KeyError(f'Symbol {sym} not in symbol table')

    def __getitem__(self, sym):
        """``table[sym]`` is shorthand for ``table.get_symbol(sym)``."""
        return self.get_symbol(sym)

    def __contains__(self, sym):
        """
        Report whether ``sym`` has been added, without modifying the table.

        Without this method ``sym in table`` would fall back to probing
        ``__getitem__`` with 0, 1, 2, ..., which on a mutable table never
        terminates and silently adds junk symbols.
        """
        return sym in self.__sym2id

    def __call__(self, symid):
        """
        Map an id back to its symbol.

        Ids that are not in the table are returned unchanged.
        """
        return self.__id2sym.get(symid, symid)


# Retained although compiler() below no longer needs it: other code in this
# module may rely on `re` being imported here -- TODO confirm and move to the
# top-of-file imports.
import re


def compiler(strings, isymbols=None, osymbols=None, acceptor=False, add_symbols=True):
    """
    Build an FST from lines written in OpenFst-style text format.

    Each entry of ``strings`` is split on whitespace (leading/trailing
    whitespace and blank lines are ignored) and interpreted by field count:

    * arc:          ``src dst ilabel [olabel] [weight]`` (``olabel`` is only
      present for transducers, i.e. ``acceptor=False``)
    * final state:  ``state [weight]``

    State 0 is always created first and made the initial state.  Arc weights
    default to ``1`` when omitted; the final weight is left to
    ``set_final_weight``'s default when omitted.  Explicit weights are
    converted with ``float`` for both arcs and final states.

    :param strings: iterable of text lines.
    :param isymbols: input symbol table (indexable by label).  When ``None``
        and ``add_symbols`` is true a fresh mutable ``SymbolTable`` is made.
    :param osymbols: output symbol table, handled like ``isymbols``; only
        used for transducers.
    :param acceptor: build an acceptor (no output label column).
    :param add_symbols: create missing symbol tables automatically.  When
        false and a table is missing, labels must be integers and are used
        as label ids directly.
    :returns: the constructed ``FST``.
    :raises ValueError: on a line with an invalid number of fields or a
        non-numeric state/weight (or label, when no symbol table is in use).
    """
    # "-" is registered first so that it receives id 0, which the FST code in
    # this module treats as epsilon.
    if add_symbols and isymbols is None:
        isymbols = SymbolTable(mutableOnFly=add_symbols)
        isymbols.add_symbol("-")
    if add_symbols and not acceptor and osymbols is None:
        osymbols = SymbolTable(mutableOnFly=add_symbols)
        osymbols.add_symbol("-")

    if acceptor:
        f = FST(acceptor=acceptor, string_mapper=isymbols)
    else:
        f = FST(acceptor=acceptor, string_mapper=isymbols, output_string_mapper=osymbols)

    # always have 0 be the initial state per OpenFST compilation standards
    states = {0: f.add_state()}
    f.set_initial_state(states[0])

    def state_for(text_id):
        # Map a state number from the text to an FST state, creating it on
        # first use.
        if text_id not in states:
            states[text_id] = f.add_state()
        return states[text_id]

    def label_id(table, label):
        # Without a symbol table the label column must already be numeric.
        if table is None:
            return int(label)
        return table[label]

    # Maximum number of fields on an arc line (the last one is the weight).
    max_fields = 4 if acceptor else 5

    for s in strings:
        # str.split() with no argument drops leading/trailing whitespace, so a
        # trailing newline cannot produce an empty field that shifts the
        # field count.
        parts = s.split()
        if not parts:
            continue  # blank line
        count = len(parts)
        if count > max_fields:
            raise ValueError(f'Syntax error: {s}')

        if count >= max_fields - 1:
            src_id, dst_id = int(parts[0]), int(parts[1])
            weight = float(parts[max_fields - 1]) if count == max_fields else 1
            s0 = state_for(src_id)
            s1 = state_for(dst_id)
            isym = label_id(isymbols, parts[2])
            if acceptor:
                f.add_arc(s0, s1, weight, isym)
            else:
                osym = label_id(osymbols, parts[3])
                f.add_arc(s0, s1, weight, isym, osym)
        elif count == 2:
            # Convert like arc weights; passing the raw text through would
            # hand a str to the semiring.
            f.set_final_weight(state_for(int(parts[0])), float(parts[1]))
        elif count == 1:
            f.set_final_weight(state_for(int(parts[0])))
        else:
            # e.g. three fields for a transducer: too many for a final-state
            # line, too few for an arc.
            raise ValueError(f'Syntax error: {s}')
    return f