From 7c391c19ff298d3e21edfb5b21950c2df3fdb481 Mon Sep 17 00:00:00 2001 From: Jesper Stemann Andersen Date: Fri, 13 Dec 2024 17:48:31 +0100 Subject: [PATCH] Added string --- src/MLX.jl | 3 ++- src/string.jl | 27 +++++++++++++++++++++++++++ test/runtests.jl | 1 + test/string_tests.jl | 10 ++++++++++ 4 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 src/string.jl create mode 100644 test/string_tests.jl diff --git a/src/MLX.jl b/src/MLX.jl index b251c22..4f6832b 100644 --- a/src/MLX.jl +++ b/src/MLX.jl @@ -1,6 +1,6 @@ module MLX -export MLXArray, MLXException, MLXMatrix, MLXVecOrMat, MLXVector +export MLXArray, MLXException, MLXMatrix, MLXString, MLXVecOrMat, MLXVector include(joinpath(@__DIR__, "Wrapper.jl")) @@ -9,6 +9,7 @@ include(joinpath(@__DIR__, "device.jl")) include(joinpath(@__DIR__, "error_handling.jl")) include(joinpath(@__DIR__, "metal.jl")) include(joinpath(@__DIR__, "stream.jl")) +include(joinpath(@__DIR__, "string.jl")) function __init__() register_error_handler() diff --git a/src/string.jl b/src/string.jl new file mode 100644 index 0000000..2337767 --- /dev/null +++ b/src/string.jl @@ -0,0 +1,27 @@ +mutable struct MLXString <: AbstractString + mlx_string::Wrapper.mlx_string + + function MLXString(ptr::Ptr{UInt8}) + mlx_string = Wrapper.mlx_string_new_data(ptr) + this = new(mlx_string) + finalizer(d -> Wrapper.mlx_string_free(d.mlx_string), this) + return this + end +end + +MLXString(str::AbstractString) = MLXString(pointer(str)) + +function Base.ncodeunits(str::MLXString) + return length(unsafe_string(Wrapper.mlx_string_data(str.mlx_string))) +end + +function Base.iterate(str::MLXString) + return Base.iterate(unsafe_string(Wrapper.mlx_string_data(str.mlx_string))) +end +function Base.iterate(str::MLXString, i::Int) + return Base.iterate(unsafe_string(Wrapper.mlx_string_data(str.mlx_string)), i) +end + +function convert(::Type{String}, str::MLXString) + return unsafe_string(Wrapper.mlx_string_data(str.mlx_string)) +end diff --git a/test/runtests.jl b/test/runtests.jl index b5b5399..c2923aa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,4 +7,5 @@ using Test if !Sys.iswindows() # Windows is hanging include(joinpath(@__DIR__, "stream_tests.jl")) end + include(joinpath(@__DIR__, "string_tests.jl")) end diff --git a/test/string_tests.jl b/test/string_tests.jl new file mode 100644 index 0000000..e618700 --- /dev/null +++ b/test/string_tests.jl @@ -0,0 +1,10 @@ +using MLX +using Test + +@testset "string" begin + @testset "constructor" begin + s = "" + s_mlx = MLXString(s) + @test s_mlx == s + end +end