Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 8 additions & 37 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# This workflow tests IMPSY across multiple Python and TensorFlow versions.
# This workflow tests IMPSY across Python versions and platforms
# using the poetry-locked dependency versions.
# Full matrix runs on PRs; a smaller smoke test runs on pushes to main.

name: Install and run IMPSY
Expand All @@ -18,7 +19,7 @@ jobs:
if: github.event_name == 'push'
strategy:
matrix:
platform: [ubuntu-latest, macos-latest]
platform: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -49,42 +50,14 @@ jobs:
method: PURGE

# Full compatibility matrix: runs on PRs to main
# Tests Python 3.11 + 3.12 on Ubuntu + macOS with poetry-locked deps (TF 2.19.1)
full-matrix:
if: github.event_name == 'pull_request'
strategy:
fail-fast: false
matrix:
include:
# TF 2.16.2 — Python 3.12 only (3.11 crashes with SIGABRT)
- platform: ubuntu-latest
python-version: "3.12"
tf-version: "2.16.2"
tfp-version: "0.24.0"
- platform: macos-latest
python-version: "3.12"
tf-version: "2.16.2"
tfp-version: "0.24.0"
# TF 2.18.1 — Python 3.11, 3.12 — Ubuntu, macOS
- platform: ubuntu-latest
python-version: "3.11"
tf-version: "2.18.1"
tfp-version: "0.25.0"
- platform: ubuntu-latest
python-version: "3.12"
tf-version: "2.18.1"
tfp-version: "0.25.0"
- platform: macos-latest
python-version: "3.11"
tf-version: "2.18.1"
tfp-version: "0.25.0"
- platform: macos-latest
python-version: "3.12"
tf-version: "2.18.1"
tfp-version: "0.25.0"
# TF 2.20.0 — blocked on missing Flex delegate for SELECT_TF_OPS
# See https://github.com/google-ai-edge/LiteRT/issues/6458
# Re-enable once LiteRT supports Flex ops in ai_edge_litert

platform: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.11", "3.12"]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v4
Expand All @@ -94,10 +67,8 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: 'poetry'
- name: Install dependencies with specific TF version
run: |
poetry install
poetry run pip install dm-tree>=0.1.9 tensorflow==${{ matrix.tf-version }} tensorflow-probability==${{ matrix.tfp-version }}
- name: Install dependencies
run: poetry install
- name: Lint with flake8
run: |
poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=examples
Expand Down
6 changes: 5 additions & 1 deletion impsy/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ def get_tflite_interpreter(model_path: str):


def get_tflite_converter(model):
"""Return a TFLite converter from a Keras model."""
"""Return a TFLite converter from a Keras model.

Uses from_keras_model (works on TF 2.19.1+ with Keras 3). This preserves
the model's input names in the TFLite SignatureDef across Python versions.
"""
return tf.lite.TFLiteConverter.from_keras_model(model)


Expand Down
29 changes: 23 additions & 6 deletions impsy/mdrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,25 @@ def prepare(self) -> None:
else:
self.runner = None
self.interpreter.allocate_tensors()
self._input_index = {d['name']: d['index'] for d in self.interpreter.get_input_details()}
self._output_details = self.interpreter.get_output_details()
# Build input index mapping expected names to tensor indices.
# Input order is: inputs, state_h_0, state_c_0, state_h_1, state_c_1, ...
expected_names = ['inputs']
for i in range(self.n_layers):
expected_names.append(f'state_h_{i}')
expected_names.append(f'state_c_{i}')
input_details = sorted(self.interpreter.get_input_details(), key=lambda d: d['index'])
self._input_index = {name: d['index'] for name, d in zip(expected_names, input_details)}
# Build output index mapping by shape: MDN output has more columns than n_hidden_units.
output_details = self.interpreter.get_output_details()
self._mdn_output_index = None
state_outputs = []
for d in output_details:
if d['shape'][-1] != self.n_hidden_units:
self._mdn_output_index = d['index']
else:
state_outputs.append(d['index'])
state_outputs.sort()
self._state_output_indices = state_outputs


def _discover_output_keys(self):
Expand Down Expand Up @@ -434,11 +451,11 @@ def generate(self, prev_value: np.ndarray) -> np.ndarray:
for name, value in runner_input.items():
self.interpreter.set_tensor(self._input_index[name], self._to_numpy(value))
self.interpreter.invoke()
## Outputs ordered: [mdn_out, state_h_0, state_c_0, ...]
mdn_params = self.interpreter.get_tensor(self._output_details[0]['index']).squeeze()
## Extract outputs by shape-based indices
mdn_params = self.interpreter.get_tensor(self._mdn_output_index).squeeze()
for i in range(self.n_layers):
self.lstm_states[2 * i] = self.interpreter.get_tensor(self._output_details[1 + 2 * i]['index'])
self.lstm_states[2 * i + 1] = self.interpreter.get_tensor(self._output_details[2 + 2 * i]['index'])
self.lstm_states[2 * i] = self.interpreter.get_tensor(self._state_output_indices[2 * i])
self.lstm_states[2 * i + 1] = self.interpreter.get_tensor(self._state_output_indices[2 * i + 1])
# sample from the MDN:
new_sample = (
mdn.sample_from_output(
Expand Down
Loading
Loading