Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 148 additions & 14 deletions mfst/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
#
# Wrapper for OpenFst that supports defining custom semirings in python
# and drawing FSTs in ipython notebooks

#
# Edit 2021 Eric Fosler-Lussier to create SymbolTable class, allow for different input/output symbol tables,
# allow compilation

import openfst_wrapper_backend as _backend
from collections import namedtuple as _namedtuple, deque as _deque
Expand Down Expand Up @@ -158,7 +160,7 @@ class FST(object):
Wraps a mutable FST class
"""

def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, *, _fst=None):
def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, output_string_mapper=None, *, _fst=None):
if semiring_class is None:
semiring_class = PythonValueSemiringWeight
elif type(semiring_class) is not type:
Expand All @@ -170,6 +172,7 @@ def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, *, _
semiring_class = BooleanSemiringWeight
acceptor = True
string_mapper = f._string_mapper
output_string_mapper = f._output_string_mapper
else:
assert issubclass(semiring_class, AbstractSemiringWeight), "first argument is not iterable or a semiring class"

Expand Down Expand Up @@ -201,6 +204,10 @@ def __init__(self, semiring_class=None, acceptor=False, string_mapper=None, *, _
# and passing that through ord() and chr() to print them in the graphics that we are drawing. So if that is being
# used, then this will get set to true inside of add_arc
self._string_mapper = string_mapper
# add output string mapper, which can also be the same alphabet
self._output_string_mapper = output_string_mapper
if self._output_string_mapper is None:
self._output_string_mapper = string_mapper

def _make_weight(self, w):
if isinstance(w, self._semiring_class):
Expand Down Expand Up @@ -232,7 +239,8 @@ def constructor(self, _fst=None, **kwargs):
_fst=_fst,
semiring_class=self._semiring_class,
acceptor=self._acceptor,
string_mapper=self._string_mapper
string_mapper=self._string_mapper,
output_string_mapper=self._output_string_mapper
)
params.update(kwargs)
return type(self)(**params)
Expand Down Expand Up @@ -282,18 +290,24 @@ def get_unique_output_string(self):
else:
mapper = lambda x: x

if self._output_string_mapper is not None:
omapper = self._output_string_mapper
else:
omapper = lambda x: x


while state != -1:
edges = list(self.get_arcs(state))
if len(edges) != 1:
raise RuntimeError("FST does not contain exactly one path")
l = edges[0].output_label
if l != 0: # the epsilon state
ret.append(mapper(l))
ret.append(omapper(l))
if edges[0].nextstate in seen:
raise RuntimeError("FST contains cycle")
seen.add(state)
state = edges[0].nextstate
if mapper is chr:
if omapper is chr:
return ''.join(ret)
return ret

Expand Down Expand Up @@ -358,8 +372,8 @@ def add_arc(self, from_state, to_state,
if isinstance(output_label, str):
assert len(output_label) == 1, "FST string labels can only be a single character"
output_label = ord(output_label)
if self._string_mapper is None:
self._string_mapper = chr
if self._output_string_mapper is None:
self._output_string_mapper = chr
if self._acceptor:
# acceptors are machines with the same input and output label
if output_label == 0: # if not set just copy the value
Expand Down Expand Up @@ -700,7 +714,7 @@ def lift(self, semiring=None, converter=None):
else:
converter = lambda x: x

ret = FST(semiring, acceptor=self._acceptor, string_mapper=self._string_mapper)
ret = FST(semiring, acceptor=self._acceptor, string_mapper=self._string_mapper, output_string_mapper=self._output_string_mapper)
zero = self.semiring_zero
for i in range(self.num_states):
ret.add_state() # would be nice if this did not need to be called in a loop
Expand Down Expand Up @@ -820,6 +834,7 @@ def __getstate__(self):
'semiring_class': self._semiring_class,
'acceptor': self._acceptor,
'string_mapper': self._string_mapper,
'output_string_mapper': self._output_string_mapper,
'num_states': self.num_states,
'arcs': [[tuple(x) for x in self.get_arcs(s)] for s in self.states],
'initial_state': self.initial_state,
Expand All @@ -830,11 +845,13 @@ def __setstate__(self, d):
semiring_class=d['semiring_class'],
acceptor=d['acceptor'],
string_mapper=d['string_mapper'],
output_string_mapper=d['output_string_mapper']
)
self._fst = f._fst
self._semiring_class = f._semiring_class
self._acceptor = f._acceptor
self._string_mapper = f._string_mapper
self._output_string_mapper = f._output_string_mapper
for i in range(d['num_states']):
self.add_state()
self.initial_state = d['initial_state']
Expand Down Expand Up @@ -898,6 +915,20 @@ def make_label(x):
else:
make_label = str

if self._output_string_mapper is not None:
if self._output_string_mapper is chr:
def make_olabel(x):
if x == 32:
return '(spc)'
elif x < 32:
return str(x)
else:
return chr(x)
else:
make_olabel = self._output_string_mapper
else:
make_olabel = make_label

for sid in range(self.num_states):
to = defaultdict(list)
for arc in self.get_arcs(sid):
Expand All @@ -909,12 +940,14 @@ def make_label(x):
label += '\u03B5' # epsilon
else:
label += make_label(arc.input_label)
if arc.input_label != arc.output_label:
label += ':'
if arc.output_label == 0:
label += '\u03B5'
else:
label += make_label(arc.output_label)
if arc.output_label == 0:
olabel = '\u03B5' # epsilon
else:
olabel = make_olabel(arc.output_label)

if label != olabel:
label += ':'+olabel

if one != arc.weight:
label += f'/{arc.weight}'
to[arc.nextstate].append(label)
Expand Down Expand Up @@ -1025,3 +1058,104 @@ def make_label(x):
</script>
''')
return ''.join(ret2)

class SymbolTable(object):
    """
    Two-way table between symbols and integer ids; can function as a callable
    to pass as input to FST constructor.

    Ids are handed out sequentially, starting at 0, in the order symbols are
    first added.  Indexing (``table[sym]``) maps a symbol to its id, while
    calling (``table(symid)``) maps an id back to its symbol.

    :param mutableOnFly: when true, looking up an unknown symbol adds it to
        the table instead of raising.
    :param strict: when true, adding a symbol that is already present raises
        instead of returning the existing id.
    """

    def __init__(self, mutableOnFly=False, strict=False):
        self.__sym2id = {}
        self.__id2sym = {}
        self.__mutableOnFly = mutableOnFly
        self.__strict = strict

    def add_symbol(self, sym):
        """
        Add ``sym`` to the table and return its id.

        If the symbol is already present its existing id is returned, unless
        the table is strict, in which case an Exception is raised.
        """
        if sym in self.__sym2id:
            # The strictness flag lives on the instance; a bare ``strict``
            # would be an undefined name here.
            if self.__strict:
                raise Exception(f'Symbol {sym} already in symbol table')
            return self.__sym2id[sym]
        # Symbols are never removed, so the current size is the next free id.
        newid = len(self.__sym2id)
        self.__sym2id[sym] = newid
        self.__id2sym[newid] = sym
        return newid

    def get_symbol(self, sym):
        """
        Return the id of ``sym``.

        An unknown symbol is added on the fly when the table was created with
        ``mutableOnFly=True``; otherwise an Exception is raised.
        """
        if sym in self.__sym2id:
            return self.__sym2id[sym]
        if self.__mutableOnFly:
            return self.add_symbol(sym)
        raise Exception(f'Symbol {sym} not in symbol table')

    def __getitem__(self, sym):
        """Return the id of ``sym``; same as ``get_symbol(sym)``."""
        return self.get_symbol(sym)

    def __call__(self, symid):
        """
        Return the symbol stored for ``symid``.

        Ids that are not in the table are returned unchanged.
        """
        if symid in self.__id2sym:
            return self.__id2sym[symid]
        return symid

import re

def compiler(strings, isymbols=None, osymbols=None, acceptor=False, add_symbols=True):
    """
    Build an FST from an iterable of whitespace-separated text lines.

    Each non-blank line is one of (for an acceptor the ``osym`` column is
    omitted):

    * ``src dst isym osym [weight]`` -- an arc from state ``src`` to ``dst``
    * ``state [weight]``             -- marks ``state`` as final

    State numbers in the text are mapped to states of the new FST as they are
    first seen; text state 0 is always created first and made the initial
    state.  Blank lines are ignored.

    :param strings: iterable of lines to compile.
    :param isymbols: table used to look up input symbols (indexed with
        ``isymbols[sym]``) and passed to the FST as ``string_mapper``.
    :param osymbols: same for output symbols; unused for acceptors.
    :param acceptor: build an acceptor (no output column) instead of a
        transducer.
    :param add_symbols: when a needed symbol table is not supplied, create a
        SymbolTable that grows as new symbols are encountered.
    :return: the constructed FST.
    :raises Exception: ``'Syntax error: <line>'`` when a line has an
        unsupported number of fields.
    :raises ValueError: when a state number or weight is not numeric.
    """
    # create isymbols, osymbols
    # "-" is registered first so it takes id 0, the label value the FST code
    # in this file treats as epsilon.
    if add_symbols and isymbols is None:
        isymbols = SymbolTable(mutableOnFly=add_symbols)
        isymbols.add_symbol("-")
    if add_symbols and not acceptor and osymbols is None:
        osymbols = SymbolTable(mutableOnFly=add_symbols)  # this is to handle bug in printing
        osymbols.add_symbol("-")

    if acceptor:
        f = FST(acceptor=acceptor, string_mapper=isymbols)
    else:
        f = FST(acceptor=acceptor, string_mapper=isymbols, output_string_mapper=osymbols)

    # always have 0 be the initial state per OpenFST compilation standards
    states = {0: f.add_state()}
    f.set_initial_state(states[0])

    def state_for(number):
        # Map a state number from the text to an FST state, creating it on first use.
        if number not in states:
            states[number] = f.add_state()
        return states[number]

    # Number of fields on an arc line that carries no weight.
    arc_fields = 3 if acceptor else 4

    for s in strings:
        # split() with no argument drops leading/trailing whitespace (including
        # the newline of lines read from a file), so no empty fields appear.
        parts = s.split()
        if not parts:
            continue
        count = len(parts)
        if count > arc_fields + 1:
            raise Exception('Syntax error: ' + s)

        src = int(parts[0])
        if count >= arc_fields:
            dst = int(parts[1])
            # The optional weight is the single field after the arc columns.
            weight = float(parts[arc_fields]) if count == arc_fields + 1 else 1
            s0 = state_for(src)
            s1 = state_for(dst)
            isym = isymbols[parts[2]]
            if acceptor:
                f.add_arc(s0, s1, weight, isym)
            else:
                osym = osymbols[parts[3]]
                f.add_arc(s0, s1, weight, isym, osym)
        elif count == 2:
            # Final state with an explicit weight; converted to float so it is
            # the same kind of value as the arc weights.
            f.set_final_weight(state_for(src), float(parts[1]))
        elif count == 1:
            f.set_final_weight(state_for(src))
        else:
            raise Exception('Syntax error: ' + s)
    return f