Skip to content
5 changes: 2 additions & 3 deletions gtwrap/interface_parser/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def from_parse_result(parse_result: ParseResults):
return ArgumentList([])

def __repr__(self) -> str:
return ",".join([repr(x) for x in self.args_list])
return ", ".join([repr(x) for x in self.args_list])

def __len__(self) -> int:
return len(self.args_list)
Expand Down Expand Up @@ -182,8 +182,7 @@ def __init__(self,
self.args.parent = self

def __repr__(self) -> str:
return "GlobalFunction: {}{}({})".format(self.return_type, self.name,
self.args)
return f"GlobalFunction: {self.name}({self.args}) -> {self.return_type}"

def to_cpp(self) -> str:
"""Generate the C++ code for wrapping."""
Expand Down
2 changes: 1 addition & 1 deletion gtwrap/interface_parser/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TypenameAndInstantiations:
"""
Rule to parse the template parameters.

template<typename POSE> // POSE is the Instantiation.
template<typename POSE = {Pose2, Pose3}> // Pos2 and Pose3 are the `Instantiation`s.
"""
rule = (
IDENT("typename") #
Expand Down
39 changes: 23 additions & 16 deletions gtwrap/interface_parser/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class Typename:
"""

namespaces_name_rule = delimitedList(IDENT, "::")
instantiation_name_rule = delimitedList(IDENT, "::")
rule = (
namespaces_name_rule("namespaces_and_name") #
).setParseAction(lambda t: Typename(t))
Expand Down Expand Up @@ -164,7 +163,7 @@ class Type:
"""
rule = (
Optional(CONST("is_const")) #
+ (BasicType.rule("basic") | CustomType.rule("qualified")) # BR
+ (BasicType.rule("basic") | CustomType.rule("custom")) # BR
+ Optional(
SHARED_POINTER("is_shared_ptr") | RAW_POINTER("is_ptr")
| REF("is_ref")) #
Expand Down Expand Up @@ -192,9 +191,9 @@ def from_parse_result(t: ParseResults):
is_ref=t.is_ref,
is_basic=True,
)
elif t.qualified:
elif t.custom:
return Type(
typename=t.qualified.typename,
typename=t.custom.typename,
is_const=t.is_const,
is_shared_ptr=t.is_shared_ptr,
is_ptr=t.is_ptr,
Expand All @@ -212,6 +211,13 @@ def __repr__(self) -> str:
is_const="const " if self.is_const else "",
is_ptr_or_ref=" " + is_ptr_or_ref if is_ptr_or_ref else "")

def get_typename(self):
"""
Get the typename of this type without any qualifiers.
E.g. for `const gtsam::Pose3& pose` this will return `gtsam::Pose3`.
"""
return self.typename.to_cpp()

def to_cpp(self) -> str:
"""
Generate the C++ code for wrapping.
Expand All @@ -221,22 +227,18 @@ def to_cpp(self) -> str:

if self.is_shared_ptr:
typename = "std::shared_ptr<{typename}>".format(
typename=self.typename.to_cpp())
typename=self.get_typename())
elif self.is_ptr:
typename = "{typename}*".format(typename=self.typename.to_cpp())
elif self.is_ref:
typename = typename = "{typename}&".format(
typename=self.typename.to_cpp())
typename=self.get_typename())
else:
typename = self.typename.to_cpp()
typename = self.get_typename()

return ("{const}{typename}".format(
const="const " if self.is_const else "", typename=typename))

def get_typename(self):
"""Convenience method to get the typename of this type."""
return self.typename.name


class TemplatedType:
"""
Expand Down Expand Up @@ -283,16 +285,21 @@ def __repr__(self):
return "TemplatedType({typename.namespaces}::{typename.name})".format(
typename=self.typename)

def to_cpp(self):
def get_typename(self):
"""
Generate the C++ code for wrapping.
Get the typename of this type without any qualifiers.
E.g. for `const std::vector<double>& indices` this will return `std::vector<double>`.
"""
# Use Type.to_cpp to do the heavy lifting for the template parameters.
template_args = ", ".join([t.to_cpp() for t in self.template_params])

typename = "{typename}<{template_args}>".format(
typename=self.typename.qualified_name(),
template_args=template_args)
return f"{self.typename.qualified_name()}<{template_args}>"

def to_cpp(self):
"""
Generate the C++ code for wrapping.
"""
typename = self.get_typename()

if self.is_shared_ptr:
typename = f"std::shared_ptr<{typename}>"
Expand Down
17 changes: 15 additions & 2 deletions gtwrap/matlab_wrapper/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,15 @@
class CheckMixin:
"""Mixin to provide various checks."""
# Data types that are primitive types
not_ptr_type: Tuple = ('int', 'double', 'bool', 'char', 'unsigned char',
'size_t')
not_ptr_type: Tuple = (
"int",
"double",
"bool",
"char",
"unsigned char",
"size_t",
"Key", # This is an alias for a uint64_t
)
# Ignore the namespace for these datatypes
ignore_namespace: Tuple = ('Matrix', 'Vector', 'Point2', 'Point3')
# Methods that should be ignored
Expand Down Expand Up @@ -111,6 +118,9 @@ def _format_type_name(self,
is_constructor: bool = False,
is_method: bool = False):
"""
Helper method to get the string version of `type_name` which can go into the wrapper generated C++ code.
This is specific to the semantics of Matlab.

Args:
type_name: an interface_parser.Typename to reformat
separator: the statement to add between namespaces and typename
Expand All @@ -133,6 +143,9 @@ def _format_type_name(self,
if name not in self.ignore_namespace and namespace != '':
formatted_type_name += namespace + separator

# Get string representation so we can use as dict key.
name = str(name)

if is_constructor:
formatted_type_name += self.data_type.get(name) or name
elif is_method:
Expand Down
20 changes: 12 additions & 8 deletions gtwrap/matlab_wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self,
'Matrix': 'double',
'int': 'numeric',
'size_t': 'numeric',
'Key': 'numeric',
'bool': 'logical'
}
# Map the data type into the type used in Matlab methods.
Expand All @@ -68,6 +69,7 @@ def __init__(self,
'Point3': 'double',
'Vector': 'double',
'Matrix': 'double',
'Key': 'numeric',
'bool': 'bool'
}
# The amount of times the wrapper has created a call to geometry_wrapper
Expand Down Expand Up @@ -108,7 +110,8 @@ def _update_wrapper_id(self,

Args:
collector_function: tuple storing info about the wrapper function
(namespace, class instance, function name, function object)
(namespace, class/function instance,
type of collector function, method object if class instance)
id_diff: constant to add to the id in the map
function_name: Optional custom function_name.

Expand Down Expand Up @@ -372,9 +375,9 @@ def _unwrap_argument(self, arg, arg_id=0, instantiated_class=None):
ctype_sep=ctype_sep, ctype=ctype_camel, id=arg_id)

else:
arg_type = "{ctype}".format(ctype=arg.ctype.typename.name)
arg_type = "{ctype}".format(ctype=self._format_type_name(arg.ctype.typename))
unwrap = 'unwrap< {ctype} >(in[{id}]);'.format(
ctype=arg.ctype.typename.name, id=arg_id)
ctype=self._format_type_name(arg.ctype.typename), id=arg_id)

return arg_type, unwrap

Expand Down Expand Up @@ -578,6 +581,7 @@ def wrap_global_function(self, function):
# Get all combinations of parameters
param_wrap = ''

# Iterate through possible overloads of the function
for i, overload in enumerate(function):
param_wrap += ' if' if i == 0 else ' elseif'
param_wrap += ' length(varargin) == '
Expand Down Expand Up @@ -1218,7 +1222,7 @@ def wrap_namespace(self, namespace, add_mex_file=True):
if isinstance(func, parser.GlobalFunction)
]

self.wrap_methods(all_funcs, True, global_ns=namespace)
self.wrap_methods(all_funcs, global_funcs=True, global_ns=namespace)

return wrapped

Expand Down Expand Up @@ -1333,7 +1337,7 @@ def _collector_return(self,
prefix=' ')
else:
expanded += ' out[0] = wrap< {0} >({1});'.format(
ctype.typename.name, obj)
self._format_type_name(ctype.typename), obj)

return expanded

Expand Down Expand Up @@ -1365,8 +1369,8 @@ def wrap_collector_function_return(self, method, instantiated_class=None):
method_name += method.original.name

elif isinstance(method, parser.GlobalFunction):
method_name = self._format_global_function(method, '::')
method_name += method.name
namespace = self._format_global_function(method, '::')
method_name = namespace + method.to_cpp()

else:
if isinstance(method.parent, instantiator.InstantiatedClass):
Expand Down Expand Up @@ -1624,7 +1628,7 @@ def generate_collector_function(self, func_id):

body += self._wrapper_unwrap_arguments(collector_func[1].args)[1]
body += self.wrap_collector_function_return(
collector_func[1]) + '\n}\n'
collector_func[1]) + "\n}\n"

collector_function += body

Expand Down
9 changes: 4 additions & 5 deletions gtwrap/template_instantiator/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class InstantiatedGlobalFunction(parser.GlobalFunction):
template<T = {double}>
T add(const T& x, const T& y);
"""

def __init__(self, original, instantiations=(), new_name=''):
self.original = original
self.instantiations = instantiations
Expand Down Expand Up @@ -54,16 +55,14 @@ def __init__(self, original, instantiations=(), new_name=''):
def to_cpp(self):
"""Generate the C++ code for wrapping."""
if self.original.template:
instantiated_names = [
instantiated_params = [
"::".join(inst.namespaces + [inst.instantiated_name()])
for inst in self.instantiations
]
ret = "{}<{}>".format(self.original.name,
",".join(instantiated_names))
ret = f"{self.original.name}<{','.join(instantiated_params)}>"
else:
ret = self.original.name
return ret

def __repr__(self):
return "Instantiated {}".format(
super(InstantiatedGlobalFunction, self).__repr__())
return f"Instantiated {super().__repr__}"
3 changes: 2 additions & 1 deletion gtwrap/template_instantiator/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def instantiate_type(
ctype.typename.instantiations[idx].name =\
instantiations[template_idx]


str_arg_typename = str(ctype.typename)

# Check if template is a scoped template e.g. T::Value where T is the template
Expand All @@ -88,6 +87,7 @@ def instantiate_type(
is_ref=ctype.is_ref,
is_basic=ctype.is_basic,
)

# Check for exact template match.
elif str_arg_typename in template_typenames:
idx = template_typenames.index(str_arg_typename)
Expand Down Expand Up @@ -228,6 +228,7 @@ class InstantiationHelper:
parent=parent)
```
"""

def __init__(self, instantiation_type: InstantiatedMembers):
self.instantiation_type = instantiation_type

Expand Down
15 changes: 15 additions & 0 deletions matlab.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,14 @@ mxArray* wrap<int>(const int& value) {
return result;
}

// specialization to gtsam::Key which is an alias for uint64_t
template<>
mxArray* wrap<uint64_t>(const uint64_t& value) {
mxArray *result = scalar(mxUINT32OR64_CLASS);
*(uint64_t*)mxGetData(result) = value;
return result;
}

// specialization to double -> just double
template<>
mxArray* wrap<double>(const double& value) {
Expand Down Expand Up @@ -330,6 +338,13 @@ int unwrap<int>(const mxArray* array) {
return myGetScalar<int>(array);
}

// specialization to gtsam::Key which is an alias for uint64_t
template<>
uint64_t unwrap<uint64_t>(const mxArray* array) {
checkScalar(array,"unwrap<uint64_t>");
return myGetScalar<uint64_t>(array);
}

// specialization to size_t
template<>
size_t unwrap<size_t>(const mxArray* array) {
Expand Down
7 changes: 7 additions & 0 deletions tests/expected/matlab/EliminateDiscrete.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function varargout = EliminateDiscrete(varargin)
if length(varargin) == 2 && isa(varargin{1},'gtsam.DiscreteFactorGraph') && isa(varargin{2},'gtsam.Ordering')
[ varargout{1} varargout{2} ] = functions_wrapper(25, varargin{:});
else
error('Arguments do not match any overload of function EliminateDiscrete');
end
end
7 changes: 7 additions & 0 deletions tests/expected/matlab/FindKarcherMeanPoint3.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function varargout = FindKarcherMeanPoint3(varargin)
if length(varargin) == 1 && isa(varargin{1},'std.vectorgtsam::Point3')
varargout{1} = functions_wrapper(28, varargin{:});
else
error('Arguments do not match any overload of function FindKarcherMeanPoint3');
end
end
7 changes: 7 additions & 0 deletions tests/expected/matlab/FindKarcherMeanPose2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function varargout = FindKarcherMeanPose2(varargin)
if length(varargin) == 1 && isa(varargin{1},'std.vectorgtsam::Pose2')
varargout{1} = functions_wrapper(30, varargin{:});
else
error('Arguments do not match any overload of function FindKarcherMeanPose2');
end
end
7 changes: 7 additions & 0 deletions tests/expected/matlab/FindKarcherMeanPose3.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function varargout = FindKarcherMeanPose3(varargin)
if length(varargin) == 1 && isa(varargin{1},'std.vectorgtsam::Pose3')
varargout{1} = functions_wrapper(31, varargin{:});
else
error('Arguments do not match any overload of function FindKarcherMeanPose3');
end
end
7 changes: 7 additions & 0 deletions tests/expected/matlab/FindKarcherMeanRot2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function varargout = FindKarcherMeanRot2(varargin)
if length(varargin) == 1 && isa(varargin{1},'std.vectorgtsam::Rot2')
varargout{1} = functions_wrapper(29, varargin{:});
else
error('Arguments do not match any overload of function FindKarcherMeanRot2');
end
end
7 changes: 7 additions & 0 deletions tests/expected/matlab/FindKarcherMeanRot3.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function varargout = FindKarcherMeanRot3(varargin)
if length(varargin) == 1 && isa(varargin{1},'std.vectorgtsam::Rot3')
varargout{1} = functions_wrapper(31, varargin{:});
else
error('Arguments do not match any overload of function FindKarcherMeanRot3');
end
end
7 changes: 7 additions & 0 deletions tests/expected/matlab/FindKarcherMeanSO3.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function varargout = FindKarcherMeanSO3(varargin)
if length(varargin) == 1 && isa(varargin{1},'std.vectorgtsam::SO3')
varargout{1} = functions_wrapper(29, varargin{:});
else
error('Arguments do not match any overload of function FindKarcherMeanSO3');
end
end
7 changes: 7 additions & 0 deletions tests/expected/matlab/FindKarcherMeanSO4.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function varargout = FindKarcherMeanSO4(varargin)
if length(varargin) == 1 && isa(varargin{1},'std.vectorgtsam::SO4')
varargout{1} = functions_wrapper(30, varargin{:});
else
error('Arguments do not match any overload of function FindKarcherMeanSO4');
end
end
Loading
Loading