Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions crates/mun_abi/src/autogen_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,22 @@ impl StructInfo {
unsafe { slice::from_raw_parts(self.field_sizes, self.num_fields as usize) }
}
}

/// Returns the index of the field matching the specified `field_name`.
pub fn find_field_index(struct_info: &StructInfo, field_name: &str) -> Result<usize, String> {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think at some point we should figure out a way to do some sort of hash lookup. Because this might become slow.

struct_info
.field_names()
.enumerate()
.find(|(_, name)| *name == field_name)
.map(|(idx, _)| idx)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
struct_info.name(),
field_name
)
})
}
}

impl fmt::Display for StructInfo {
Expand Down
54 changes: 20 additions & 34 deletions crates/mun_runtime/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,53 +105,39 @@ macro_rules! invoke_fn_impl {
if arg_types.len() != num_args {
return Err(format!(
"Invalid number of arguments. Expected: {}. Found: {}.",
num_args,
arg_types.len(),
num_args,
));
}

#[allow(unused_mut, unused_variables)]
let mut idx = 0;
$(
if arg_types[idx].guid != $Arg.type_guid() {
return Err(format!(
"Invalid argument type at index {}. Expected: {}. Found: {}.",
idx,
$Arg.type_name(),
arg_types[idx].name(),
));
}
crate::reflection::equals_argument_type(&arg_types[idx], &$Arg)
.map_err(|(expected, found)| {
format!(
"Invalid argument type at index {}. Expected: {}. Found: {}.",
idx,
expected,
found,
)
})?;
idx += 1;
)*

if let Some(return_type) = function_info.signature.return_type() {
match return_type.group {
abi::TypeGroup::FundamentalTypes => {
if return_type.guid != Output::type_guid() {
return Err(format!(
"Invalid return type. Expected: {}. Found: {}",
Output::type_name(),
return_type.name(),
));
}
}
abi::TypeGroup::StructTypes => {
if <Struct as ReturnTypeReflection>::type_guid() != Output::type_guid() {
return Err(format!(
"Invalid return type. Expected: {}. Found: Struct",
Output::type_name(),
));
}
}
}

crate::reflection::equals_return_type::<Output>(return_type)
} else if <() as ReturnTypeReflection>::type_guid() != Output::type_guid() {
return Err(format!(
Err((<() as ReturnTypeReflection>::type_name(), Output::type_name()))
} else {
Ok(())
}.map_err(|(expected, found)| {
format!(
"Invalid return type. Expected: {}. Found: {}",
Output::type_name(),
<() as ReturnTypeReflection>::type_name(),
));
}
expected,
found,
)
})?;

Ok(function_info)
}) {
Expand Down
39 changes: 35 additions & 4 deletions crates/mun_runtime/src/reflection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,38 @@
use crate::marshal::MarshalInto;
use abi::Guid;
use crate::{marshal::MarshalInto, Struct};
use abi::{Guid, TypeInfo};
use md5;

/// Returns whether the specified argument type matches the `type_info`.
pub fn equals_argument_type<'e, 'f, T: ArgumentReflection>(
type_info: &'e TypeInfo,
arg: &'f T,
) -> Result<(), (&'e str, &'f str)> {
if type_info.guid != arg.type_guid() {
Err((type_info.name(), arg.type_name()))
} else {
Ok(())
}
}

/// Returns whether the specified return type matches the `type_info`.
pub fn equals_return_type<T: ReturnTypeReflection>(
type_info: &TypeInfo,
) -> Result<(), (&str, &str)> {
match type_info.group {
abi::TypeGroup::FundamentalTypes => {
if type_info.guid != T::type_guid() {
return Err((type_info.name(), T::type_name()));
}
}
abi::TypeGroup::StructTypes => {
if <Struct as ReturnTypeReflection>::type_guid() != T::type_guid() {
return Err(("struct", T::type_name()));
}
}
}
Ok(())
}

/// A type to emulate dynamic typing across compilation units for static types.
pub trait ReturnTypeReflection: Sized + 'static {
/// The resulting type after marshaling.
Expand All @@ -19,9 +50,9 @@ pub trait ReturnTypeReflection: Sized + 'static {
}

/// A type to emulate dynamic typing across compilation units for statically typed values.
pub trait ArgumentReflection {
pub trait ArgumentReflection: Sized {
/// The resulting type after dereferencing.
type Marshalled: Sized;
type Marshalled: MarshalInto<Self>;

/// Retrieves the `Guid` of the value's type.
fn type_guid(&self) -> Guid {
Expand Down
112 changes: 42 additions & 70 deletions crates/mun_runtime/src/struct.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::{
marshal::MarshalInto,
reflection::{ArgumentReflection, ReturnTypeReflection},
reflection::{
equals_argument_type, equals_return_type, ArgumentReflection, ReturnTypeReflection,
},
};
use abi::{StructInfo, TypeInfo};
use std::mem;
Expand Down Expand Up @@ -39,114 +41,84 @@ impl Struct {
}

/// Retrieves the value of the field corresponding to the specified `field_name`.
pub fn get<T: ReturnTypeReflection>(&self, field_name: &str) -> Result<&T, String> {
let field_idx = self
.info
.field_names()
.enumerate()
.find(|(_, name)| *name == field_name)
.map(|(idx, _)| idx)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
self.info.name(),
field_name
)
})?;

pub fn get<T: ReturnTypeReflection>(&self, field_name: &str) -> Result<T, String> {
let field_idx = StructInfo::find_field_index(&self.info, field_name)?;
let field_type = unsafe { self.info.field_types().get_unchecked(field_idx) };
if T::type_guid() != field_type.guid {
return Err(format!(
equals_return_type::<T>(&field_type).map_err(|(expected, found)| {
format!(
"Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.",
self.info.name(),
field_name,
field_type.name(),
T::type_name()
));
}
expected,
found,
)
})?;

unsafe {
let field_value = unsafe {
// If we found the `field_idx`, we are guaranteed to also have the `field_offset`
let offset = *self.info.field_offsets().get_unchecked(field_idx);
// self.ptr is never null
Ok(&*self.raw.0.add(offset as usize).cast::<T>())
}
// TODO: The unsafe `read` fn could be avoided by adding the `Clone` bound on
// `T::Marshalled`, but its only available on nightly:
// `ReturnTypeReflection<Marshalled: Clone>`
self.raw
.0
.add(offset as usize)
.cast::<T::Marshalled>()
.read()
};
Ok(field_value.marshal_into(Some(*field_type)))
}

/// Replaces the value of the field corresponding to the specified `field_name` and returns the
/// old value.
pub fn replace<T: ArgumentReflection>(
&mut self,
field_name: &str,
mut value: T,
value: T,
) -> Result<T, String> {
let field_idx = self
.info
.field_names()
.enumerate()
.find(|(_, name)| *name == field_name)
.map(|(idx, _)| idx)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
self.info.name(),
field_name
)
})?;

let field_idx = StructInfo::find_field_index(&self.info, field_name)?;
let field_type = unsafe { self.info.field_types().get_unchecked(field_idx) };
if value.type_guid() != field_type.guid {
return Err(format!(
equals_argument_type(&field_type, &value).map_err(|(expected, found)| {
format!(
"Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.",
self.info.name(),
field_name,
field_type.name(),
value.type_name()
));
}
expected,
found,
)
})?;

let mut marshalled: T::Marshalled = value.marshal();
let ptr = unsafe {
// If we found the `field_idx`, we are guaranteed to also have the `field_offset`
let offset = *self.info.field_offsets().get_unchecked(field_idx);
// self.ptr is never null
&mut *self.raw.0.add(offset as usize).cast::<T>()
&mut *self.raw.0.add(offset as usize).cast::<T::Marshalled>()
};
mem::swap(&mut value, ptr);
Ok(value)
mem::swap(&mut marshalled, ptr);
Ok(marshalled.marshal_into(Some(*field_type)))
}

/// Sets the value of the field corresponding to the specified `field_name`.
pub fn set<T: ArgumentReflection>(&mut self, field_name: &str, value: T) -> Result<(), String> {
let field_idx = self
.info
.field_names()
.enumerate()
.find(|(_, name)| *name == field_name)
.map(|(idx, _)| idx)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
self.info.name(),
field_name
)
})?;

let field_idx = StructInfo::find_field_index(&self.info, field_name)?;
let field_type = unsafe { self.info.field_types().get_unchecked(field_idx) };
if value.type_guid() != field_type.guid {
return Err(format!(
equals_argument_type(&field_type, &value).map_err(|(expected, found)| {
format!(
"Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.",
self.info.name(),
field_name,
field_type.name(),
value.type_name()
));
}
expected,
found,
)
})?;

unsafe {
// If we found the `field_idx`, we are guaranteed to also have the `field_offset`
let offset = *self.info.field_offsets().get_unchecked(field_idx);
// self.ptr is never null
*self.raw.0.add(offset as usize).cast::<T>() = value;
*self.raw.0.add(offset as usize).cast::<T::Marshalled>() = value.marshal();
}
Ok(())
}
Expand Down
51 changes: 41 additions & 10 deletions crates/mun_runtime/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,15 @@ fn marshal_struct() {
let mut driver = TestDriver::new(
r#"
struct(gc) Foo { a: int, b: bool, c: float, };
struct Bar(Foo);

fn foo_new(a: int, b: bool, c: float): Foo {
Foo { a, b, c, }
}
fn bar_new(foo: Foo): Bar {
Bar(foo)
}

fn foo_a(foo: Foo):int { foo.a }
fn foo_b(foo: Foo):bool { foo.b }
fn foo_c(foo: Foo):float { foo.c }
Expand All @@ -401,9 +406,9 @@ fn marshal_struct() {
let b = true;
let c = 1.23f64;
let mut foo: Struct = invoke_fn!(driver.runtime, "foo_new", a, b, c).unwrap();
assert_eq!(Ok(&a), foo.get::<i64>("a"));
assert_eq!(Ok(&b), foo.get::<bool>("b"));
assert_eq!(Ok(&c), foo.get::<f64>("c"));
assert_eq!(Ok(a), foo.get::<i64>("a"));
assert_eq!(Ok(b), foo.get::<bool>("b"));
assert_eq!(Ok(c), foo.get::<f64>("c"));

let d = 6i64;
let e = false;
Expand All @@ -412,19 +417,45 @@ fn marshal_struct() {
foo.set("b", e).unwrap();
foo.set("c", f).unwrap();

assert_eq!(Ok(&d), foo.get::<i64>("a"));
assert_eq!(Ok(&e), foo.get::<bool>("b"));
assert_eq!(Ok(&f), foo.get::<f64>("c"));
assert_eq!(Ok(d), foo.get::<i64>("a"));
assert_eq!(Ok(e), foo.get::<bool>("b"));
assert_eq!(Ok(f), foo.get::<f64>("c"));

assert_eq!(Ok(d), foo.replace("a", a));
assert_eq!(Ok(e), foo.replace("b", b));
assert_eq!(Ok(f), foo.replace("c", c));

assert_eq!(Ok(&a), foo.get::<i64>("a"));
assert_eq!(Ok(&b), foo.get::<bool>("b"));
assert_eq!(Ok(&c), foo.get::<f64>("c"));
assert_eq!(Ok(a), foo.get::<i64>("a"));
assert_eq!(Ok(b), foo.get::<bool>("b"));
assert_eq!(Ok(c), foo.get::<f64>("c"));

assert_invoke_eq!(i64, a, driver, "foo_a", foo.clone());
assert_invoke_eq!(bool, b, driver, "foo_b", foo.clone());
assert_invoke_eq!(f64, c, driver, "foo_c", foo);
assert_invoke_eq!(f64, c, driver, "foo_c", foo.clone());

let mut bar: Struct = invoke_fn!(driver.runtime, "bar_new", foo.clone()).unwrap();
let foo2 = bar.get::<Struct>("0").unwrap();
assert_eq!(Ok(a), foo2.get::<i64>("a"));
assert_eq!(foo2.get::<bool>("b"), foo.get::<bool>("b"));
assert_eq!(foo2.get::<f64>("c"), foo.get::<f64>("c"));

// Specify invalid return type
let bar_err = bar.get::<i64>("0");
assert!(bar_err.is_err());

// Specify invalid argument type
let bar_err = bar.replace("0", 1i64);
assert!(bar_err.is_err());

// Specify invalid argument type
let bar_err = bar.set("0", 1i64);
assert!(bar_err.is_err());

// Specify invalid return type
let bar_err: Result<i64, _> = invoke_fn!(driver.runtime, "bar_new", foo);
assert!(bar_err.is_err());

// Pass invalid struct type
let bar_err: Result<Struct, _> = invoke_fn!(driver.runtime, "bar_new", bar);
assert!(bar_err.is_err());
}