diff --git a/src/diskimg.c b/src/diskimg.c
index c2c6e5a..89456a9 100644
--- a/src/diskimg.c
+++ b/src/diskimg.c
@@ -22,6 +22,11 @@ ssize_t diskimg_write(struct diskimg *diskimg,
     return write(diskimg->fd, data, size);
 }
 
+int diskimg_flush(struct diskimg *diskimg)
+{
+    return fdatasync(diskimg->fd);
+}
+
 int diskimg_init(struct diskimg *diskimg, const char *file_path)
 {
     diskimg->fd = open(file_path, O_RDWR);
diff --git a/src/diskimg.h b/src/diskimg.h
index f03b669..c513201 100644
--- a/src/diskimg.h
+++ b/src/diskimg.h
@@ -17,5 +17,6 @@ ssize_t diskimg_write(struct diskimg *diskimg,
                      void *data,
                      off_t offset,
                      size_t size);
+int diskimg_flush(struct diskimg *diskimg);
 int diskimg_init(struct diskimg *diskimg, const char *file_path);
 void diskimg_exit(struct diskimg *diskimg);
diff --git a/src/virtio-blk.c b/src/virtio-blk.c
index d94f321..3b1452c 100644
--- a/src/virtio-blk.c
+++ b/src/virtio-blk.c
@@ -70,106 +70,182 @@ static void virtio_blk_enable_vq(struct virtq *vq)
     dev->vq_thread_started = true;
 }
 
-static ssize_t virtio_blk_write(struct virtio_blk_dev *dev,
-                                void *data,
-                                off_t offset,
-                                size_t size)
+/* Snapshot of one descriptor in a chain. We copy the volatile guest fields
+ * once so subsequent decisions cannot tear against a concurrent guest write.
+ */
+struct desc_snap {
+    uint64_t addr;
+    uint32_t len;
+    uint16_t flags;
+    uint16_t id;
+};
+
+/* Walk a chain starting at the supplied head, snapshotting each descriptor
+ * into out[]. cap is the maximum supported chain length; callers pass
+ * VIRTQ_SIZE, the bound of their snapshot arrays, which also mirrors the
+ * "seen >= size" guard in the reference VMM and defends against a malformed
+ * chain. Returns the count on success, or 0 if the chain is malformed (NULL
+ * from virtq_get_avail mid-chain, or longer than cap). On any return the head
+ * has been consumed; on failure the callers deliberately stall the queue
+ * rather than publish USED against a buffer ID they never reached.
+ */
+static size_t virtio_blk_walk_chain(struct virtq *vq,
+                                    struct vring_packed_desc *head,
+                                    struct desc_snap *out,
+                                    size_t cap)
 {
-    return diskimg_write(dev->diskimg, data, offset, size);
+    out[0].addr = head->addr;
+    out[0].len = head->len;
+    out[0].flags = head->flags;
+    out[0].id = head->id;
+    size_t n = 1;
+    while (out[n - 1].flags & VRING_DESC_F_NEXT) {
+        if (n >= cap)
+            return 0;
+        struct vring_packed_desc *next = virtq_get_avail(vq);
+        if (!next)
+            return 0;
+        out[n].addr = next->addr;
+        out[n].len = next->len;
+        out[n].flags = next->flags;
+        out[n].id = next->id;
+        n++;
+    }
+    return n;
 }
 
-static ssize_t virtio_blk_read(struct virtio_blk_dev *dev,
-                               void *data,
-                               off_t offset,
-                               size_t size)
+static uint8_t virtio_blk_handle_io(struct virtio_blk_dev *dev,
+                                    vm_t *v,
+                                    const struct virtio_blk_req *req,
+                                    const struct desc_snap *chain,
+                                    size_t n,
+                                    bool needs_write,
+                                    uint32_t *out_written)
 {
-    return diskimg_read(dev->diskimg, data, offset, size);
+    /* sector * 512 must not overflow before any segment is dispatched. */
+    uint64_t cur_off;
+    if (__builtin_mul_overflow(req->sector, (uint64_t) 512, &cur_off))
+        return VIRTIO_BLK_S_IOERR;
+
+    /* used.len must report writable_total plus one byte for the status
+     * descriptor we write later, and it is a uint32_t in the packed-ring
+     * descriptor, so accumulate in 64 bits and reject any total that
+     * wouldn't fit.
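+     * Example: two 3 GiB device-writable segments each carry a len that fits
+     * in 32 bits, yet their sum plus the status byte exceeds UINT32_MAX, so
+     * the overflow check below fails the request instead of truncating it.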
+ */ + uint64_t writable_total = 0; + for (size_t i = 1; i < n - 1; i++) { + const struct desc_snap *seg = &chain[i]; + bool is_writable = (seg->flags & VRING_DESC_F_WRITE) != 0; + if (is_writable != needs_write) + return VIRTIO_BLK_S_IOERR; + + void *buf = vm_guest_buf(v, seg->addr, seg->len); + if (!buf) + return VIRTIO_BLK_S_IOERR; + + uint64_t end; + if (__builtin_add_overflow(cur_off, (uint64_t) seg->len, &end) || + end > (uint64_t) dev->diskimg->size) + return VIRTIO_BLK_S_IOERR; + + ssize_t got; + if (needs_write) + got = diskimg_read(dev->diskimg, buf, (off_t) cur_off, seg->len); + else + got = diskimg_write(dev->diskimg, buf, (off_t) cur_off, seg->len); + if (got < 0 || (size_t) got != (size_t) seg->len) + return VIRTIO_BLK_S_IOERR; + + cur_off += seg->len; + if (needs_write) { + if (__builtin_add_overflow(writable_total, (uint64_t) seg->len, + &writable_total) || + writable_total > (uint64_t) UINT32_MAX - 1) + return VIRTIO_BLK_S_IOERR; + } + } + *out_written = (uint32_t) writable_total; + return VIRTIO_BLK_S_OK; } static void virtio_blk_complete_request(struct virtq *vq) { struct virtio_blk_dev *dev = (struct virtio_blk_dev *) vq->dev; vm_t *v = container_of(dev, vm_t, virtio_blk_dev); - uint8_t status; - struct vring_packed_desc *desc; - struct virtio_blk_req req; + struct vring_packed_desc *head; - /* Wire-format header is type/reserved/sector only; the rest of struct - * virtio_blk_req is host bookkeeping and must not be overwritten by the - * guest. */ + /* Wire-format header is type/reserved/sector only; the trailing fields + * of struct virtio_blk_req are host bookkeeping. + */ const size_t hdr_sz = offsetof(struct virtio_blk_req, data); - while ((desc = virtq_get_avail(vq))) { - struct vring_packed_desc *used_desc = desc; - ssize_t io_bytes = 0; - - void *hdr = vm_guest_buf(v, desc->addr, hdr_sz); - if (!hdr || desc->len < hdr_sz) - return; - memcpy(&req, hdr, hdr_sz); - if (req.type == VIRTIO_BLK_T_IN || req.type == VIRTIO_BLK_T_OUT) { - if (!virtq_check_next(desc)) - return; - desc = virtq_get_avail(vq); - req.data_size = desc->len; - req.data = vm_guest_buf(v, desc->addr, req.data_size); - - /* Validate that the request fits in the backing store. Both the - * shift (sector*512) and the addition (offset+data_size) must not - * overflow, and the end must be within diskimg->size. Any failure - * yields VIRTIO_BLK_S_IOERR with no data transferred. */ - uint64_t off, end; - bool io_ok = false; - if (req.data && !__builtin_mul_overflow(req.sector, 512, &off) && - !__builtin_add_overflow(off, req.data_size, &end) && - end <= (uint64_t) dev->diskimg->size) { - if (req.type == VIRTIO_BLK_T_IN) - io_bytes = virtio_blk_read(dev, req.data, (off_t) off, - req.data_size); - else - io_bytes = virtio_blk_write(dev, req.data, (off_t) off, - req.data_size); - /* A short read/write leaves part of the guest buffer stale, - * so treat anything less than the full request as IOERR. */ - io_ok = io_bytes >= 0 && (size_t) io_bytes == req.data_size; - } - status = io_ok ? VIRTIO_BLK_S_OK : VIRTIO_BLK_S_IOERR; - } else { - status = VIRTIO_BLK_S_UNSUPP; - } - if (!virtq_check_next(desc)) - return; - desc = virtq_get_avail(vq); - /* The status descriptor must advertise at least one device-writable - * byte; otherwise we'd clobber memory the guest did not offer. 
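+    /* Each iteration below consumes one full descriptor chain and publishes
+     * exactly one used element for it; the only exception is a malformed
+     * chain, which stalls the queue instead. */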
+ while ((head = virtq_get_avail(vq))) { + struct desc_snap chain[VIRTQ_SIZE]; + /* Walker cap is the array bound, not the guest-controlled + * vq->info.size — virtio-pci clamps that on writes, but pass + * VIRTQ_SIZE here too as defense in depth against ABI drift. */ - if (desc->len < 1) + size_t n = virtio_blk_walk_chain(vq, head, chain, VIRTQ_SIZE); + if (n == 0) { + /* Malformed chain — buffer ID lives on the last descriptor and we + * never reached it. Publishing USED with chain[0].id risks pointing + * the driver at an unrelated in-flight chain. Stalling the queue is + * the lesser evil. + */ return; - req.status = vm_guest_buf(v, desc->addr, 1); - if (!req.status) - return; - *req.status = status; - /* used.len is total bytes the device wrote into device-writable buffers - * across the chain: the 1-byte status is always written, plus the data - * buffer on a successful IN. On any error we report only the status - * byte so the guest does not consume stale data. + } + + /* Default response: IOERR using the chain's last-descriptor id (the + * buffer ID) and len=1. Single-descriptor chains have head == last + * so this is the head's id. */ - size_t written = 1; - if (status == VIRTIO_BLK_S_OK && req.type == VIRTIO_BLK_T_IN) - written += (size_t) io_bytes; - used_desc->len = (uint32_t) written; - /* Buffer ID lives on the chain's last descriptor in packed virtqueues; - * echo it back into the head/used slot so the driver can match the - * completion to its in-flight request. + uint8_t status_byte = VIRTIO_BLK_S_IOERR; + uint16_t buffer_id = chain[n - 1].id; + uint32_t used_len = 1; + uint8_t *status_ptr = NULL; + + if (n < 2) + goto publish; + + /* Last descriptor of the chain owns the buffer ID and is the status + * descriptor; it must be device-writable with at least one byte. */ - used_desc->id = desc->id; - /* Single-writer slot until USED publishes, so a plain load of the - * current flags is safe. Release-store the new value so id and len are - * visible to the guest before the USED flag flip. + const struct desc_snap *status_desc = &chain[n - 1]; + if (!(status_desc->flags & VRING_DESC_F_WRITE) || status_desc->len < 1) + goto publish; + status_ptr = vm_guest_buf(v, status_desc->addr, 1); + if (!status_ptr) + goto publish; + + /* Header descriptor must be device-readable and span at least the + * wire-format header. */ - uint16_t new_flags = - used_desc->flags ^ (uint16_t) (1U << VRING_PACKED_DESC_F_USED); - __atomic_store_n(&used_desc->flags, new_flags, __ATOMIC_RELEASE); + if ((chain[0].flags & VRING_DESC_F_WRITE) || chain[0].len < hdr_sz) + goto publish; + void *hdr = vm_guest_buf(v, chain[0].addr, hdr_sz); + if (!hdr) + goto publish; + + struct virtio_blk_req req; + memcpy(&req, hdr, hdr_sz); + + if (req.type == VIRTIO_BLK_T_IN || req.type == VIRTIO_BLK_T_OUT) { + bool needs_write = req.type == VIRTIO_BLK_T_IN; + uint32_t writable = 0; + status_byte = virtio_blk_handle_io(dev, v, &req, chain, n, + needs_write, &writable); + used_len = writable + 1; + } else if (req.type == VIRTIO_BLK_T_FLUSH) { + status_byte = diskimg_flush(dev->diskimg) < 0 ? 
VIRTIO_BLK_S_IOERR + : VIRTIO_BLK_S_OK; + used_len = 1; + } else { + status_byte = VIRTIO_BLK_S_UNSUPP; + used_len = 1; + } + + publish: + if (status_ptr) + *status_ptr = status_byte; + virtq_publish_used(head, buffer_id, used_len); __atomic_fetch_or(&dev->virtio_pci_dev.config.isr_cap.isr_status, VIRTIO_PCI_ISR_QUEUE, __ATOMIC_RELEASE); } @@ -225,7 +301,11 @@ int virtio_blk_init_pci(struct virtio_blk_dev *virtio_blk_dev, virtio_pci_set_pci_hdr(dev, VIRTIO_PCI_DEVICE_ID_BLK, VIRTIO_BLK_PCI_CLASS, virtio_blk_dev->irq_num); virtio_pci_set_virtq(dev, virtio_blk_dev->vq, VIRTIO_BLK_VIRTQ_NUM); - virtio_pci_add_feature(dev, 0); + /* FLUSH is required for guest fsync to be honored: with the bit clear the + * Linux driver runs in writeback-without-barrier mode and a host crash can + * lose data the guest believed durable. + */ + virtio_pci_add_feature(dev, 1ULL << VIRTIO_BLK_F_FLUSH); virtio_pci_enable(dev); return 0; } @@ -245,6 +325,10 @@ void virtio_blk_exit(struct virtio_blk_dev *dev) throw_err("Failed to wake virtio-blk worker"); if (dev->vq_thread_started) pthread_join(dev->vq_avail_thread, NULL); + /* Honor guest barrier semantics on clean shutdown: writes that came back as + * VIRTIO_BLK_S_OK could still be in the host page cache. + */ + diskimg_flush(dev->diskimg); diskimg_exit(dev->diskimg); virtio_pci_exit(&dev->virtio_pci_dev); close(dev->irqfd); diff --git a/src/virtio-blk.h b/src/virtio-blk.h index 1c8b414..75d99e4 100644 --- a/src/virtio-blk.h +++ b/src/virtio-blk.h @@ -13,12 +13,16 @@ #define VIRTIO_BLK_VIRTQ_NUM 1 #define VIRTIO_BLK_PCI_CLASS 0x018000 +/* Wire-format header is the first three fields (type/reserved/sector); the + * trailing host-only bookkeeping is filled in by the device emulator from the + * descriptor chain and never read from guest memory. + */ struct virtio_blk_req { uint32_t type; uint32_t reserved; uint64_t sector; uint8_t *data; - uint16_t data_size; + uint32_t data_size; uint8_t *status; }; diff --git a/src/virtio-net.c b/src/virtio-net.c index e8338c9..455327d 100644 --- a/src/virtio-net.c +++ b/src/virtio-net.c @@ -50,14 +50,26 @@ static bool virtio_net_poll_tx(struct virtio_net_dev *dev) { struct pollfd pollfds[] = { [0] = {.fd = dev->tx_ioeventfd, .events = POLLIN}, - [1] = {.fd = dev->tapfd, .events = POLLOUT}, - [2] = {.fd = dev->stopfd, .events = POLLIN}, + [1] = {.fd = dev->stopfd, .events = POLLIN}, + [2] = {.fd = dev->tapfd, .events = dev->tx_wait_for_tap ? POLLOUT : 0}, }; int ret = poll(pollfds, 3, -1); + if (ret <= 0 || (pollfds[1].revents & POLLIN)) + return false; - return ret > 0 && (pollfds[0].revents & POLLIN) && - (pollfds[1].revents & POLLOUT) && !(pollfds[2].revents & POLLIN); + bool tx_kick = pollfds[0].revents & POLLIN; + bool tap_writable = pollfds[2].revents & POLLOUT; + + if (tx_kick) { + /* Drain the level-triggered ioeventfd so the next poll(2) blocks + * until the guest kicks again. 
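+         * The counter value (the number of coalesced kicks) is irrelevant:
+         * one pass of the TX handler drains every chain they announced.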
+         */
+        uint64_t n;
+        ssize_t ignored = read(dev->tx_ioeventfd, &n, sizeof(n));
+        (void) ignored;
+    }
+
+    return tx_kick || tap_writable;
 }
 
 static void *virtio_net_vq_avail_handler_rx(void *arg)
@@ -66,7 +78,7 @@ static void *virtio_net_vq_avail_handler_rx(void *arg)
     struct virtio_net_dev *dev = (struct virtio_net_dev *) vq->dev;
 
     while (!virtio_net_stop_requested(dev)) {
-        vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_ENABLE;
+        virtq_set_guest_event_flags(vq, VRING_PACKED_EVENT_FLAG_ENABLE);
         if (virtio_net_poll_rx(dev))
             virtq_handle_avail(vq);
     }
@@ -79,7 +91,7 @@ static void *virtio_net_vq_avail_handler_tx(void *arg)
     struct virtio_net_dev *dev = (struct virtio_net_dev *) vq->dev;
 
     while (!virtio_net_stop_requested(dev)) {
-        vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_ENABLE;
+        virtq_set_guest_event_flags(vq, VRING_PACKED_EVENT_FLAG_ENABLE);
         if (virtio_net_poll_tx(dev))
             virtq_handle_avail(vq);
     }
@@ -146,91 +158,236 @@ static void virtio_net_notify_used_tx(struct virtq *vq)
         throw_err("Failed to write the irqfd");
 }
 
+/* Snapshot of one descriptor in a chain, copied once so guest-side races
+ * cannot tear our subsequent decisions.
+ */
+struct net_desc_snap {
+    uint64_t addr;
+    uint32_t len;
+    uint16_t flags;
+    uint16_t id;
+};
+
+/* Walk the chain rooted at head, copying each descriptor into out[]. cap
+ * bounds the chain length. Returns the count on success or 0 on a malformed
+ * chain (NULL mid-walk or chain longer than cap). The head has been consumed
+ * regardless; on failure the RX/TX paths below stall the queue rather than
+ * publish USED against a buffer ID they never reached.
+ */
+static size_t net_walk_chain(struct virtq *vq,
+                             struct vring_packed_desc *head,
+                             struct net_desc_snap *out,
+                             size_t cap)
+{
+    out[0].addr = head->addr;
+    out[0].len = head->len;
+    out[0].flags = head->flags;
+    out[0].id = head->id;
+    size_t n = 1;
+    while (out[n - 1].flags & VRING_DESC_F_NEXT) {
+        if (n >= cap)
+            return 0;
+        struct vring_packed_desc *next = virtq_get_avail(vq);
+        if (!next)
+            return 0;
+        out[n].addr = next->addr;
+        out[n].len = next->len;
+        out[n].flags = next->flags;
+        out[n].id = next->id;
+        n++;
+    }
+    return n;
+}
+
 void virtio_net_complete_request_rx(struct virtq *vq)
 {
     struct virtio_net_dev *dev = (struct virtio_net_dev *) vq->dev;
     vm_t *v = container_of(dev, vm_t, virtio_net_dev);
-    struct vring_packed_desc *desc;
-
-    while ((desc = virtq_get_avail(vq)) != NULL) {
-        size_t virtio_header_len = sizeof(struct virtio_net_hdr_v1);
-        /* desc lives in guest-writable memory; snapshot the length we'll
-         * validate and use so a concurrent guest write cannot widen the
-         * access past the bounds check. */
-        uint32_t buf_len = desc->len;
-        uint8_t *data = vm_guest_buf(v, desc->addr, buf_len);
-        if (!data || buf_len < virtio_header_len) {
-            vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_DISABLE;
+    struct vring_packed_desc *head;
+    const size_t hdr_len = sizeof(struct virtio_net_hdr_v1);
+
+    while ((head = virtq_get_avail(vq))) {
+        struct net_desc_snap chain[VIRTQ_SIZE];
+        /* See virtio-blk for why we cap at VIRTQ_SIZE rather than the
+         * guest-controlled vq->info.size.
+         */
+        size_t n = net_walk_chain(vq, head, chain, VIRTQ_SIZE);
+        if (n == 0) {
+            /* Malformed chain — buffer ID lives on the last descriptor and we
+             * never reached it. Publishing USED with chain[0].id (which the
+             * driver may have left stale on a multi-desc head) could cause the
+             * driver to look up the wrong in-flight request and advance
+             * next_used_idx by an unrelated chain length. Stalling is the
+             * lesser evil; a misbehaving driver hangs only itself.
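+             * Parking the event-suppression flag at DISABLE below mirrors the
+             * queue-empty path; the worker loop re-enables it before its next
+             * poll.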
+ */ + virtq_set_guest_event_flags(vq, VRING_PACKED_EVENT_FLAG_DISABLE); return; } - struct virtio_net_hdr_v1 *virtio_hdr = - (struct virtio_net_hdr_v1 *) data; - memset(virtio_hdr, 0, sizeof(struct virtio_net_hdr_v1)); + uint16_t buffer_id = chain[n - 1].id; + uint32_t used_len = 0; + + /* Build iov over device-writable buffers; reject chains that mix + * directions (per RX rules every descriptor must be writable). + */ + struct iovec iov[VIRTQ_SIZE]; + size_t iov_n = 0; + size_t writable_total = 0; + bool ok = true; + for (size_t i = 0; i < n; i++) { + if (!(chain[i].flags & VRING_DESC_F_WRITE)) { + ok = false; + break; + } + void *buf = vm_guest_buf(v, chain[i].addr, chain[i].len); + if (!buf) { + ok = false; + break; + } + iov[iov_n].iov_base = buf; + iov[iov_n].iov_len = chain[i].len; + iov_n++; + writable_total += chain[i].len; + } + if (!ok || writable_total < hdr_len) + goto rx_publish; + + /* Reserve the leading hdr_len bytes for the virtio-net header, which + * the device fills in itself. Walk the iov advancing past the header + * so readv writes the frame starting right after it; zero-len entries + * are fine to leave in the array. + */ + struct virtio_net_hdr_v1 net_hdr = {0}; + /* Without VIRTIO_NET_F_MRG_RXBUF the driver always sees one buffer per + * packet, so num_buffers is always 1 here. + */ + net_hdr.num_buffers = 1; + + size_t hdr_remaining = hdr_len; + size_t hdr_offset = 0; + for (size_t i = 0; i < iov_n && hdr_remaining > 0; i++) { + size_t take = + iov[i].iov_len < hdr_remaining ? iov[i].iov_len : hdr_remaining; + memcpy(iov[i].iov_base, (uint8_t *) &net_hdr + hdr_offset, take); + iov[i].iov_base = (uint8_t *) iov[i].iov_base + take; + iov[i].iov_len -= take; + hdr_remaining -= take; + hdr_offset += take; + } - virtio_hdr->num_buffers = 1; + ssize_t got = readv(dev->tapfd, iov, (int) iov_n); + if (got < 0) + goto rx_publish; + used_len = (uint32_t) hdr_len + (uint32_t) got; - ssize_t read_bytes = read(dev->tapfd, data + virtio_header_len, - buf_len - virtio_header_len); - if (read_bytes < 0) { - vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_DISABLE; - return; - } - desc->len = virtio_header_len + read_bytes; - /* Single-descriptor chain: head and last alias the same slot, so - * the buffer ID the driver wrote in desc->id is already correct. - * Release-store flags so the len update lands before the guest - * observes the USED flag flip. */ - uint16_t new_flags = - desc->flags ^ (uint16_t) (1U << VRING_PACKED_DESC_F_USED); - __atomic_store_n(&desc->flags, new_flags, __ATOMIC_RELEASE); + rx_publish: + virtq_publish_used(head, buffer_id, used_len); __atomic_fetch_or(&dev->virtio_pci_dev.config.isr_cap.isr_status, VIRTIO_PCI_ISR_QUEUE, __ATOMIC_RELEASE); + if (used_len == 0) { + /* Wedged — back off so we don't hot-loop on a broken chain. */ + virtq_set_guest_event_flags(vq, VRING_PACKED_EVENT_FLAG_DISABLE); + return; + } + /* Process exactly one chain per call so the worker can re-poll the tap + * before draining the next packet. + */ return; } - vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_DISABLE; - return; + virtq_set_guest_event_flags(vq, VRING_PACKED_EVENT_FLAG_DISABLE); } void virtio_net_complete_request_tx(struct virtq *vq) { struct virtio_net_dev *dev = (struct virtio_net_dev *) vq->dev; vm_t *v = container_of(dev, vm_t, virtio_net_dev); - struct vring_packed_desc *desc; - while ((desc = virtq_get_avail(vq)) != NULL) { - size_t virtio_header_len = sizeof(struct virtio_net_hdr_v1); - /* See rx path: snapshot len before bounds check to defeat TOCTOU. 
*/ - uint32_t buf_len = desc->len; - uint8_t *data = vm_guest_buf(v, desc->addr, buf_len); - - if (!data || buf_len < virtio_header_len) { - vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_DISABLE; + struct vring_packed_desc *head; + const size_t hdr_len = sizeof(struct virtio_net_hdr_v1); + + /* We have been woken up; clear the retry-pending flag here so every + * exit path (publish-USED, malformed-chain return, queue-empty break) + * leaves it false. The transient writev() EAGAIN path below is the + * only one that sets it back to true, and it returns immediately. + */ + dev->tx_wait_for_tap = false; + + while (true) { + uint16_t avail_idx = vq->next_avail_idx; + bool used_wrap_count = vq->used_wrap_count; + if (!(head = virtq_get_avail(vq))) + break; + struct net_desc_snap chain[VIRTQ_SIZE]; + size_t n = net_walk_chain(vq, head, chain, VIRTQ_SIZE); + if (n == 0) { + /* See RX path: don't publish USED with a stale id. */ + virtq_set_guest_event_flags(vq, VRING_PACKED_EVENT_FLAG_DISABLE); return; } + uint16_t buffer_id = chain[n - 1].id; + + /* Build iov over device-readable buffers; reject chains that mix + * directions (per TX rules every descriptor must be readable). + */ + struct iovec iov[VIRTQ_SIZE]; + size_t iov_n = 0; + size_t total = 0; + bool ok = true; + for (size_t i = 0; i < n; i++) { + if (chain[i].flags & VRING_DESC_F_WRITE) { + ok = false; + break; + } + void *buf = vm_guest_buf(v, chain[i].addr, chain[i].len); + if (!buf) { + ok = false; + break; + } + iov[iov_n].iov_base = buf; + iov[iov_n].iov_len = chain[i].len; + iov_n++; + total += chain[i].len; + } + if (!ok || total < hdr_len) + goto tx_publish; + + /* Strip the virtio-net header from the front of the iov before + * handing it to writev — the TAP device wants raw frame bytes. + */ + size_t skip_remaining = hdr_len; + for (size_t i = 0; i < iov_n && skip_remaining > 0; i++) { + size_t take = iov[i].iov_len < skip_remaining ? iov[i].iov_len + : skip_remaining; + iov[i].iov_base = (uint8_t *) iov[i].iov_base + take; + iov[i].iov_len -= take; + skip_remaining -= take; + } - uint8_t *actual_data = data + virtio_header_len; - size_t actual_data_len = buf_len - virtio_header_len; - - struct iovec iov[1]; - iov[0].iov_base = actual_data; - iov[0].iov_len = actual_data_len; - - ssize_t write_bytes = writev(dev->tapfd, iov, 1); - if (write_bytes < 0) { - vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_DISABLE; - return; + ssize_t wrote = writev(dev->tapfd, iov, (int) iov_n); + if (wrote < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + /* Keep the chain in-flight and retry it once the TAP fd is + * writable again instead of dropping the guest packet. + */ + vq->next_avail_idx = avail_idx; + vq->used_wrap_count = used_wrap_count; + dev->tx_wait_for_tap = true; + virtq_set_guest_event_flags(vq, + VRING_PACKED_EVENT_FLAG_DISABLE); + return; + } } - /* TX buffers are device-readable only, so zero bytes were written - * to device-writable parts. */ - desc->len = 0; - uint16_t new_flags = - desc->flags ^ (uint16_t) (1U << VRING_PACKED_DESC_F_USED); - __atomic_store_n(&desc->flags, new_flags, __ATOMIC_RELEASE); + + tx_publish: + /* TX buffers are device-readable only — no bytes were written to + * device-writable buffers, so used.len = 0. + */ + virtq_publish_used(head, buffer_id, 0); __atomic_fetch_or(&dev->virtio_pci_dev.config.isr_cap.isr_status, VIRTIO_PCI_ISR_QUEUE, __ATOMIC_RELEASE); - return; + /* Drain the entire virtq in one wakeup. 
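+         * This matches virtio_blk_complete_request draining its queue; only
+         * the RX path stops after one chain, so it can re-poll the tap.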
+         * The ioeventfd was already read in poll_tx, so if we returned early
+         * any remaining chains would sit untouched until the next guest kick.
+         */
     }
-    vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_DISABLE;
-    return;
+    virtq_set_guest_event_flags(vq, VRING_PACKED_EVENT_FLAG_DISABLE);
 }
 
 static struct virtq_ops virtio_net_ops[VIRTIO_NET_VIRTQ_NUM] = {
diff --git a/src/virtio-net.h b/src/virtio-net.h
index 09d391b..be50fb4 100644
--- a/src/virtio-net.h
+++ b/src/virtio-net.h
@@ -22,6 +22,7 @@ struct virtio_net_dev {
     pthread_t tx_thread;
     bool rx_thread_started;
     bool tx_thread_started;
+    bool tx_wait_for_tap;
     bool enable;
 };
 
diff --git a/src/virtio-pci.c b/src/virtio-pci.c
index a2de270..908d7fe 100644
--- a/src/virtio-pci.c
+++ b/src/virtio-pci.c
@@ -114,10 +114,18 @@ static void virtio_pci_space_write(struct virtio_pci_dev *dev,
              offset <= VIRTIO_PCI_COMMON_Q_USEDHI) {
         uint16_t select = dev->config.common_cfg.queue_select;
         uint64_t info_offset = offset - VIRTIO_PCI_COMMON_Q_SIZE;
-        if (select < dev->config.common_cfg.num_queues)
+        if (select < dev->config.common_cfg.num_queues) {
             memcpy((void *) ((uintptr_t) &dev->vq[select].info + info_offset),
                    data, size);
+            /* Clamp guest-supplied queue_size to what we advertised
+             * (VIRTQ_SIZE). Without this, a guest writing a larger value
+             * would let chain walks blow past the VIRTQ_SIZE-sized stack
+             * arrays in the device emulators.
+             */
+            if (dev->vq[select].info.size > VIRTQ_SIZE)
+                dev->vq[select].info.size = VIRTQ_SIZE;
+        }
     }
     /* guest notify buffer avail */
     else if (offset ==
@@ -141,9 +149,10 @@ static void virtio_pci_space_read(struct virtio_pci_dev *dev,
     if (offset < offsetof(struct virtio_pci_config, dev_cfg)) {
         if (offset == offsetof(struct virtio_pci_config, isr_cap) &&
             size <= sizeof(dev->config.isr_cap.isr_status)) {
-            /* Read-and-clear in one atomic step so a worker thread's
-             * concurrent fetch_or on isr_status cannot be lost between the
-             * load and the zero-back. Acquire pairs with workers' release. */
+            /* Read-and-clear in one atomic step so a worker thread's concurrent
+             * fetch_or on isr_status cannot be lost between the load and the
+             * zero-back. Acquire pairs with workers' release.
+             */
             uint32_t status = __atomic_exchange_n(
                 &dev->config.isr_cap.isr_status, 0, __ATOMIC_ACQUIRE);
             memcpy(data, &status, size);
@@ -208,7 +217,8 @@ static void virtio_pci_set_cap(struct virtio_pci_dev *dev, uint8_t next)
     caps[VIRTIO_PCI_CAP_ISR_CFG]->length = sizeof(struct virtio_pci_isr_cap);
 
     /* FIXME: The offset for the dev-specific configuration MUST be 4-byte
-     * aligned */
+     * aligned
+     */
     caps[VIRTIO_PCI_CAP_DEVICE_CFG]->offset =
         offsetof(struct virtio_pci_config, dev_cfg);
     caps[VIRTIO_PCI_CAP_DEVICE_CFG]->length = 0;
diff --git a/src/virtq.c b/src/virtq.c
index 10cd131..52c166d 100644
--- a/src/virtq.c
+++ b/src/virtq.c
@@ -21,7 +21,6 @@ void virtq_enable(struct virtq *vq)
 
 void virtq_disable(struct virtq *vq) {}
 
-#define VIRTQ_SIZE 128
 void virtq_init(struct virtq *vq, void *dev, struct virtq_ops *ops)
 {
     vq->info.size = VIRTQ_SIZE;
@@ -58,6 +57,28 @@ struct vring_packed_desc *virtq_get_avail(struct virtq *vq)
     return desc;
 }
 
+void virtq_publish_used(struct vring_packed_desc *head,
+                        uint16_t id,
+                        uint32_t len)
+{
+    /* Buffer ID belongs in the head/used slot in packed virtqueues; the
+     * driver uses it to match the completion to its in-flight request. The
+     * len update must be visible before the USED flag flip, so write id and
+     * len with plain stores and then use a release-store on flags.
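+     * The matching acquire is the driver's flags load when it scans for
+     * completions, which guarantees it observes id and len in order.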
+     */
+    head->id = id;
+    head->len = len;
+    uint16_t flags = head->flags ^ (uint16_t) (1U << VRING_PACKED_DESC_F_USED);
+    __atomic_store_n(&head->flags, flags, __ATOMIC_RELEASE);
+}
+
+void virtq_set_guest_event_flags(struct virtq *vq, uint16_t value)
+{
+    /* The consumer side reads guest_event->flags with __ATOMIC_ACQUIRE in
+     * virtq_handle_avail; pair with a release-store so completion writes
+     * land before the suppression-flag transition is visible. */
+    __atomic_store_n(&vq->guest_event->flags, value, __ATOMIC_RELEASE);
+}
+
 void virtq_handle_avail(struct virtq *vq)
 {
     if (!vq->info.enable)
diff --git a/src/virtq.h b/src/virtq.h
index 4a524bd..81b74df 100644
--- a/src/virtq.h
+++ b/src/virtq.h
@@ -4,6 +4,10 @@
 #include 
 #include 
 
+/* Maximum descriptors per packed virtqueue. Also bounds chain length so a
+ * malformed chain cannot loop the device. */
+#define VIRTQ_SIZE 128
+
 struct virtq;
 
 struct virtq_ops {
@@ -36,6 +40,10 @@ struct virtq {
 
 struct vring_packed_desc *virtq_get_avail(struct virtq *vq);
 bool virtq_check_next(struct vring_packed_desc *desc);
+void virtq_publish_used(struct vring_packed_desc *head,
+                        uint16_t id,
+                        uint32_t len);
+void virtq_set_guest_event_flags(struct virtq *vq, uint16_t value);
 void virtq_enable(struct virtq *vq);
 void virtq_disable(struct virtq *vq);
 void virtq_complete_request(struct virtq *vq);