Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/mun_abi/c
Submodule c updated 1 files
+27 −5 include/mun_abi.h
99 changes: 81 additions & 18 deletions crates/mun_abi/src/autogen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::prelude::*;
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));

use std::ffi::{c_void, CStr};
use std::os::raw::c_char;
use std::slice;

impl TypeInfo {
Expand All @@ -19,7 +20,7 @@ impl PartialEq for TypeInfo {
}
}

impl FunctionInfo {
impl FunctionSignature {
pub fn name(&self) -> &str {
unsafe { CStr::from_ptr(self.name) }
.to_str()
Expand All @@ -41,13 +42,16 @@ impl FunctionInfo {
pub fn return_type(&self) -> Option<&TypeInfo> {
unsafe { self.return_type.as_ref() }
}

pub unsafe fn fn_ptr(&self) -> *const c_void {
self.fn_ptr
}
}

impl ModuleInfo {
/// Returns the module's full `path`.
pub fn path(&self) -> &str {
unsafe { CStr::from_ptr(self.path) }
.to_str()
.expect("Module path contains invalid UTF8")
}

// /// Finds the type's fields that match `filter`.
// pub fn find_fields(&self, filter: fn(&&FieldInfo) -> bool) -> impl Iterator<Item = &FieldInfo> {
// self.fields.iter().map(|f| *f).filter(filter)
Expand All @@ -63,19 +67,6 @@ impl ModuleInfo {
// self.fields.iter().map(|f| *f)
// }

/// Finds the module's functions that match `filter`.
pub fn find_functions<F>(&self, filter: F) -> impl Iterator<Item = &FunctionInfo>
where
F: FnMut(&&FunctionInfo) -> bool,
{
self.functions().iter().filter(filter)
}

/// Retrieves the module's function with the specified `name`, if it exists.
pub fn get_function(&self, name: &str) -> Option<&FunctionInfo> {
self.functions().iter().find(|f| f.name() == name)
}

/// Retrieves the module's functions.
pub fn functions(&self) -> &[FunctionInfo] {
if self.num_functions == 0 {
Expand All @@ -85,3 +76,75 @@ impl ModuleInfo {
}
}
}

impl DispatchTable {
pub fn iter_mut(&mut self) -> impl Iterator<Item = (&mut *const c_void, &FunctionSignature)> {
if self.num_entries == 0 {
(&mut []).iter_mut().zip((&[]).iter())
} else {
let ptrs =
unsafe { slice::from_raw_parts_mut(self.fn_ptrs, self.num_entries as usize) };
let signatures =
unsafe { slice::from_raw_parts(self.signatures, self.num_entries as usize) };

ptrs.iter_mut().zip(signatures.iter())
}
}

pub fn ptrs_mut(&mut self) -> &mut [*const c_void] {
if self.num_entries == 0 {
&mut []
} else {
unsafe { slice::from_raw_parts_mut(self.fn_ptrs, self.num_entries as usize) }
}
}

pub fn signatures(&self) -> &[FunctionSignature] {
if self.num_entries == 0 {
&[]
} else {
unsafe { slice::from_raw_parts(self.signatures, self.num_entries as usize) }
}
}

pub unsafe fn get_ptr_unchecked(&self, idx: u32) -> *const c_void {
*self.fn_ptrs.offset(idx as isize)
}

pub fn get_ptr(&self, idx: u32) -> Option<*const c_void> {
if idx < self.num_entries {
Some(unsafe { self.get_ptr_unchecked(idx) })
} else {
None
}
}

pub unsafe fn set_ptr_unchecked(&mut self, idx: u32, ptr: *const c_void) {
*self.fn_ptrs.offset(idx as isize) = ptr;
}

pub fn set_ptr(&mut self, idx: u32, ptr: *const c_void) -> bool {
if idx < self.num_entries {
unsafe { self.set_ptr_unchecked(idx, ptr) };
true
} else {
false
}
}
}

impl AssemblyInfo {
pub fn dependencies(&self) -> impl Iterator<Item = &str> {
let dependencies = if self.num_dependencies == 0 {
&[]
} else {
unsafe { slice::from_raw_parts(self.dependencies, self.num_dependencies as usize) }
};

dependencies.iter().map(|d| {
unsafe { CStr::from_ptr(*d) }
.to_str()
.expect("dependency path contains invalid UTF8")
})
}
}
2 changes: 1 addition & 1 deletion crates/mun_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub mod prelude {
}

#[repr(u8)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Privacy {
Public = 0,
Private = 1,
Expand Down
6 changes: 3 additions & 3 deletions crates/mun_abi/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ macro_rules! downcast_fn {
($FunctionInfo:expr, fn($($T:ident),*) -> $Output:ident) => {{
let num_args = $crate::count_args!($($T),*);

let arg_types = $FunctionInfo.arg_types();
let arg_types = $FunctionInfo.signature.arg_types();
if arg_types.len() != num_args {
return Err(format!(
"Invalid number of arguments. Expected: {}. Found: {}.",
Expand All @@ -38,7 +38,7 @@ macro_rules! downcast_fn {
idx += 1;
)*

if let Some(return_type) = $FunctionInfo.return_type() {
if let Some(return_type) = $FunctionInfo.signature.return_type() {
if return_type.guid != Output::type_guid() {
return Err(format!(
"Invalid return type. Expected: {}. Found: {}",
Expand All @@ -54,6 +54,6 @@ macro_rules! downcast_fn {
));
}

Ok(unsafe { core::mem::transmute($FunctionInfo.fn_ptr()) })
Ok(unsafe { core::mem::transmute($FunctionInfo.fn_ptr) })
}}
}
86 changes: 54 additions & 32 deletions crates/mun_runtime/src/assembly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,30 @@ use std::fs;
use std::io;
use std::path::{Path, PathBuf};

use crate::library::Library;
use crate::DispatchTable;
use failure::Error;
use mun_abi::FunctionInfo;
use libloading::{Library, Symbol};
use mun_abi::AssemblyInfo;

const LIB_DIR: &str = "tmp";

/// An assembly is the smallest compilable unit of code in Mun.
pub struct Assembly {
library_path: PathBuf,
tmp_path: PathBuf,
library: Option<Library>,
info: AssemblyInfo,
}

impl Assembly {
/// Loads an assembly for the library at `library_path` and its dependencies.
pub fn load(library_path: &Path) -> Result<Self, Error> {
let library_name = library_path.file_name().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "Incorrect library path.")
})?;
pub fn load(
library_path: &Path,
runtime_dispatch_table: &mut DispatchTable,
) -> Result<Self, Error> {
let library_name = library_path.file_name().ok_or(io::Error::new(
io::ErrorKind::InvalidInput,
"Incorrect library path.",
))?;

let tmp_dir = env::current_dir()?.join(LIB_DIR);
if !tmp_dir.exists() {
Expand All @@ -31,48 +36,65 @@ impl Assembly {
let tmp_path = tmp_dir.join(library_name);
fs::copy(&library_path, &tmp_path)?;

let library = Library::new(tmp_path.as_path())?;
println!("Loaded module '{}'.", library_path.to_string_lossy());
let library = Library::new(&tmp_path)?;

// Check whether the library has a symbols function
let get_info: Symbol<'_, extern "C" fn() -> AssemblyInfo> =
unsafe { library.get(b"get_info") }?;

let info = get_info();

for function in info.symbols.functions() {
runtime_dispatch_table.insert(function.signature.name(), function.clone());
}

Ok(Assembly {
library_path: library_path.to_path_buf(),
tmp_path,
library: Some(library),
info,
})
}

pub fn swap(&mut self, library_path: &Path) -> Result<(), Error> {
let library_path = library_path.canonicalize()?;
let library_name = library_path.file_name().ok_or(io::Error::new(
io::ErrorKind::InvalidInput,
"Incorrect library path.",
))?;
pub fn link(&mut self, runtime_dispatch_table: &DispatchTable) -> Result<(), Error> {
for (dispatch_ptr, fn_signature) in self.info.dispatch_table.iter_mut() {
let fn_ptr = runtime_dispatch_table
.get(fn_signature.name())
.map(|f| f.fn_ptr)
.ok_or(io::Error::new(
io::ErrorKind::NotFound,
format!(
"Failed to link: function '{}' is missing.",
fn_signature.name()
),
))?;

*dispatch_ptr = fn_ptr;
}
Ok(())
}

let tmp_path = env::current_dir()?.join(LIB_DIR).join(library_name);
pub fn swap(
&mut self,
library_path: &Path,
runtime_dispatch_table: &mut DispatchTable,
) -> Result<(), Error> {
// let library_path = library_path.canonicalize()?;

// Drop the old library, as some operating systems don't allow editing of in-use shared libraries
self.library.take();

fs::copy(&library_path, &tmp_path)?;

let library = Library::new(tmp_path.as_path())?;
println!("Reloaded module '{}'.", library_path.to_string_lossy());

self.library = Some(library);
self.library_path = library_path;
self.tmp_path = tmp_path;
for function in self.info.symbols.functions() {
runtime_dispatch_table.remove(function.signature.name());
}

// TODO: Partial hot reload of an assembly
*self = Assembly::load(library_path, runtime_dispatch_table)?;
Ok(())
}

/// Retrieves the assembly's loaded shared library.
pub fn library(&self) -> &Library {
self.library.as_ref().expect("Library was not loaded.")
}

/// Retrieves all of the assembly's functions.
pub fn functions(&self) -> impl Iterator<Item = &FunctionInfo> {
self.library().module_info().functions().iter()
pub fn info(&self) -> &AssemblyInfo {
&self.info
}

/// Returns the path corresponding tot the assembly's library.
Expand Down
Loading