Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions host_sim.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@

import logging
import asyncio

from sim_server import SimServer
from not_tcp.host import StreamProxy


log = logging.getLogger(__name__)


class HostSimulator(SimServer, StreamProxy):
# Multiple inheritance is not a *crime*, it's just an abuse of the rules.
# Tax avoidance is not tax evasion!
pass


async def run_server(port):
import sys
import ntcp_http
dut = ntcp_http.NtcpHttpServer()

with HostSimulator(dut, dut.tx, dut.rx) as srv:
server = await asyncio.start_server(
client_connected_cb=srv.client_connected, host="localhost",
port=port)
sys.stderr.write(f"listening on port {port}\n")
log.info(f"listening on port {port}\n")
await server.serve_forever()


Expand Down
42 changes: 29 additions & 13 deletions not_tcp/host.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from asyncio import StreamReader, StreamWriter
from dataclasses import dataclass
import struct
from enum import IntFlag
from typing import Optional
from asyncio import StreamReader, StreamWriter
import asyncio
import sys
import logging
import struct

log = logging.getLogger(__name__)


class Flag(IntFlag):
Expand Down Expand Up @@ -104,6 +106,7 @@ def from_bytes(cls, buf: bytes) -> (Optional["Packet"], bytes):
# Use as superclass; subclass to simulator or real
class StreamProxy:
lock = asyncio.Lock()
request_number = 0

def send(self, b: bytes()):
# Must be implemented by subclass
Expand All @@ -115,14 +118,19 @@ def recv(self) -> bytes:

def client_connected(
self, reader: StreamReader, writer: StreamWriter):
asyncio.create_task(self.client_loop(reader, writer))
r = self.request_number = self.request_number + 1
log.info(f'client {r} connected')
asyncio.create_task(self.client_loop(r, reader, writer))

async def client_loop(self, reader: StreamReader, writer: StreamWriter):
async def client_loop(self, number: int,
reader: StreamReader, writer: StreamWriter):
async with self.lock, asyncio.TaskGroup() as tg:
tg.create_task(self.run_inbound(reader))
tg.create_task(self.run_outbound(writer))
log.info(f'starting client {number} handler')
tg.create_task(self.run_inbound(number, reader))
tg.create_task(self.run_outbound(number, writer))
log.info(f'completed client {number} handlers')

async def run_inbound(self, reader: StreamReader):
async def run_inbound(self, number: int, reader: StreamReader):
p1 = Packet(flags=Flag.START, stream_id=1, body=bytes())
self.send(p1.to_bytes())
want_bytes = 256
Expand All @@ -145,15 +153,23 @@ async def run_inbound(self, reader: StreamReader):
# Input is done, in theory
p3 = Packet(flags=Flag.END, stream_id=1, body=bytes())
self.send(p3.to_bytes())
log.info(f"client {number} closed inbound connection")

async def run_outbound(self, writer: StreamWriter):
async def run_outbound(self, number: int, writer: StreamWriter):
olog = log.getChild("outbound")
total_bytes = 0
buffer = bytes()
packet_count = 0
while True:
rcvd = self.recv() # Has its own timeout, but isn't async. So:
await asyncio.sleep(0)
buffer += rcvd
(p, rem) = Packet.from_bytes(buffer)
if len(buffer) > 0:
olog.debug(
f"buffer contains bytes: "
f"{total_bytes}:{total_bytes+len(buffer)}\n")

buffer_len = len(buffer)
(p, buffer) = Packet.from_bytes(buffer)
consumed = (buffer_len - len(buffer))
total_bytes += consumed
if p is None:
continue
buffer = rem
Expand Down
13 changes: 5 additions & 8 deletions not_tcp/host_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import sys
import pytest
import asyncio
from ntcp_http import NtcpHttpServer
import pytest

from amaranth import Module
from amaranth.lib.wiring import Component, In, Out
from amaranth.lib import stream

from host_sim import HostSimulator
from http_server import capitalizer
from not_tcp.host import Packet, Flag
from sim_server import SimServer
from not_tcp.not_tcp import StreamStop
from http_server import capitalizer
from ntcp_http import NtcpHttpServer
from sim_server import SimServer


pytest_plugins = ('pytest_asyncio',)

Expand Down Expand Up @@ -65,12 +65,10 @@ def DISABLED_test_capitalize_server():
received_bytes = bytes()
received_body = bytes()
packets = []
import sys
for i in range(100):
received_bytes += srv.recv()
(packet, remainder) = Packet.from_bytes(received_bytes)
if packet is not None:
sys.stderr.write(f"{packet}\n")
received_bytes = remainder
packets += [packet]
received_body += packet.body
Expand All @@ -88,7 +86,6 @@ def DISABLED_test_capitalize_server():
received_bytes += srv.recv()
(packet, remainder) = Packet.from_bytes(received_bytes)
if packet is not None:
sys.stderr.write(f"{packet}\n")
received_bytes = remainder
packets += [packet]
received_body += packet.body
Expand Down
2 changes: 0 additions & 2 deletions ntcp_http_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,11 @@ def test_sim():

received_bytes = bytes()
packets = []
import sys
# We shouldn't have more than 100 packets for this test.
for i in range(100):
received_bytes += srv.recv()
(packet, remainder) = Packet.from_bytes(received_bytes)
if packet is not None:
sys.stderr.write(f"packet: {packet}\n")
received_bytes = remainder
packets += [packet]
if packet.end:
Expand Down
15 changes: 9 additions & 6 deletions sim_server.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import queue
import sys
import logging
import traceback
from threading import Thread

from amaranth.sim import Simulator

from stream_fixtures import StreamSender, StreamCollector

log = logging.getLogger(__name__)


class SimServer:
"""
Expand Down Expand Up @@ -100,25 +102,26 @@ def _run_sim(self, sim):
def runnable():
try:
# Uncomment this line, and indent the next, to get debug info.
# with sim.write_vcd("testout.vcd"):
sim.run()
with sim.write_vcd("testout.vcd"):
sim.run()
except Exception as e:
sys.stderr.write(f"error in Amaranth simulation: {e}\n")
log.error("error in Amaranth simulation: ", e)
# Try to force shutdown:
self._sender.die = True
raise e

return runnable

def __exit__(self, exe_type, exe_val, exe_traceback, **kwargs):
if exe_traceback is not None:
traceback.print_tb(exe_traceback)

assert self._sim_thread is not None
# Shutting down the data input should shut down the simulator;
# the data input is driving the tick.
# self._data_in.shutdown()
# .shutdown() is not available on python3.11,
# so we have to use a flag.
if exe_traceback is not None:
traceback.print_tb(exe_traceback)

self._sender.die = True
self._sim_thread.join()
19 changes: 8 additions & 11 deletions stream_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""
Test fixtures for sending and receiving in streams.
"""
import sys
import time
import random
import queue
import logging
from typing import Iterable


log = logging.getLogger(__name__)

__all__ = ["StreamCollector", "StreamSender"]


Expand Down Expand Up @@ -99,15 +101,10 @@ async def collector(ctx):
try:
q.put(batch, block=False)
except queue.Full:
sys.stderr.write(
f"queue full, saving {len(batch)} bytes "
"for later\n"
)
countup = 0
continue
except Exception as e:
sys.stderr.write(
f"error in sending data from sim: {e}\n")
log.error("error in sending data from sim: ", e)
return
batch = bytes()
countup = 0
Expand Down Expand Up @@ -150,7 +147,6 @@ class StreamSender:
# Flag bit, to kill the send_queue_active thread
die: bool = False


def __init__(self,
stream,
random_delay=False,
Expand Down Expand Up @@ -199,10 +195,11 @@ async def sender(ctx):
except queue.Empty:
data = bytes()
except queue.ShutDown:
sys.stderr.write("queue is shut down\n")
log.info("write-to-sim queue is shut down")
return
except Exception as e:
sys.stderr.write(f"unexpected exception: {e}\n")
log.error(
"write-to-sim queue unexpected exception: ", e)
raise e

if isinstance(data, str):
Expand Down
Loading