diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 8645a85ff..1854a79b6 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- #v1 Add a fresh implementation of `StringLeafHandler` to reduce unnecessary +delegation complexity. Also introduce a `typestr` function for `LeafHandler`. + ## [0.11.31] - 2025-12-11 ### Added diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py index 97f2f05f8..c85206eda 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py @@ -311,3 +311,6 @@ async def _convert_to_array_metadata() -> Sequence[ArrayMetadata]: return ret return await _convert_to_array_metadata() + + def typestr(self) -> str: + return 'jax.Array' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py index b5142dc0b..0d46e0f68 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py @@ -428,19 +428,6 @@ def get_v0_type_handler_registry( type handler registry. context: The Context to be used to default construct the LeafHandlers. """ - - def _get_typestr(leaf_type: Any) -> str: - if leaf_type == jax.Array: - return type_handlers_v0.JAX_ARRAY_TYPE_STR - elif leaf_type == np.ndarray: - return 'np.ndarray' - elif leaf_type in (int, float, bytes, np.number): - return 'scalar' - elif leaf_type == str: - return 'string' - else: - return f'{leaf_type!r}' - # register standardard v1 leaf handlers to the v0 type handler registry. handlers = [] for leaf_type, _, leaf_handler_type in leaf_handler_registry.get_all(): @@ -455,7 +442,7 @@ def _get_typestr(leaf_type: Any) -> str: leaf_type, CompatibleTypeHandler( leaf_handler, - typestr=_get_typestr(leaf_type), + typestr=leaf_handler.typestr(), ), )) return type_handler_registry.create_type_handler_registry(*handlers) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py index 86c1c7bd5..4683c6001 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py @@ -265,3 +265,6 @@ async def _convert_to_numpy_metadata() -> Sequence[NumpyMetadata]: return ret return await _convert_to_numpy_metadata() + + def typestr(self) -> str: + return 'np.ndarray' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py index fa56ad9ec..87d515a9d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py @@ -251,3 +251,6 @@ def _get_type(meta: type_handlers_v0.ScalarMetadata): return ret return await _convert_to_scalar_metadata() + + def typestr(self) -> str: + return "scalar" diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py index 6938e7bfb..a468d4bd5 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py @@ -19,13 +19,16 @@ """ import asyncio -from typing import Awaitable, Sequence, Type +from typing import Any, Awaitable, Sequence from absl import logging -from orbax.checkpoint._src.futures import future -from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0 +from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.serialization import types +from orbax.checkpoint.experimental.v1._src.synchronization import multihost +import tensorstore as ts + AbstractString = types.AbstractString StringSerializationParam = types.SerializationParam[str] @@ -34,54 +37,6 @@ ] -def _create_v0_saving_paraminfo( - param: StringSerializationParam, - context: context_lib.Context, - serialization_context: types.SerializationContext, -) -> type_handlers_v0.ParamInfo: - """Creates a V0 ParamInfo from V1 params and contexts for saving.""" - - saving_options = context.array_options.saving - - return type_handlers_v0.ParamInfo( - name=param.name, - parent_dir=serialization_context.parent_dir.path, - byte_limiter=serialization_context.byte_limiter, - is_ocdbt_checkpoint=saving_options.use_ocdbt, - use_zarr3=saving_options.use_zarr3, - ocdbt_target_data_file_size=saving_options.ocdbt_target_data_file_size, - ts_context=serialization_context.ts_context, - value_typestr="string", - ) - - -def _create_v0_restore_paraminfo( - param: types.DeserializationParam[ - AbstractString | Type[AbstractString] | None - ], - context: context_lib.Context, - deserialization_context: types.DeserializationContext, -) -> type_handlers_v0.ParamInfo: - """Creates a V0 ParamInfo from V1 params and contexts for loading.""" - - loading_options = context.array_options.loading - - return type_handlers_v0.ParamInfo( - name=param.name, - parent_dir=deserialization_context.parent_dir, - skip_deserialize=False, - byte_limiter=deserialization_context.byte_limiter, - is_ocdbt_checkpoint=deserialization_context.ocdbt_checkpoint, - ts_context=deserialization_context.ts_context, - raise_array_data_missing_error=loading_options.raise_array_data_missing_error, - use_zarr3=deserialization_context.zarr3_checkpoint, - ) - - -async def _async_futures(commit_futures: Sequence[future.Future]): - await asyncio.gather(*[asyncio.to_thread(f.result) for f in commit_futures]) - - class StringLeafHandler(types.LeafHandler[str, AbstractString]): """:py:class:`.StringLeafHandler` that implements the :py:class:`~.v1.serialization.LeafHandler` Protocol.""" @@ -98,9 +53,46 @@ def __init__( context: Context that will be used for this leaf handler. """ self._context = context_lib.get_context(context) - self._handler_impl = type_handlers_v0.StringHandler() + self._filename = '_strings.json' + logging.vlog(1, 'StringLeafHandler created.') - logging.vlog(1, "StringLeafHandler created.") + def _get_json_tspec( + self, + param_name: str, + parent_dir: path_types.Path, + ) -> dict[str, Any]: + """Gets Tensorstore spec in JSON format.""" + directory = (parent_dir / self._filename).as_posix() + kvstore_tspec = ts_utils.build_kvstore_tspec(directory, use_ocdbt=False) + tspec = { + 'driver': 'json', + 'kvstore': kvstore_tspec, + 'json_pointer': '/' + param_name, + } + return tspec + + async def _background_serialize( + self, + params: Sequence[StringSerializationParam], + serialization_context: types.SerializationContext, + ): + """Writes strings using Tensorstore in the background thread.""" + parent_dir = await serialization_context.parent_dir.await_creation() + write_coros = [] + txn = ts.Transaction() + for param in params: + tspec = self._get_json_tspec(param.name, parent_dir) + if multihost.is_primary_host( + self._context.multiprocessing_options.primary_host + ): + t = await ts.open( + tspec, + open=True, + context=serialization_context.ts_context, + ) + write_coros.append(t.with_transaction(txn).write(param.value)) # pytype: disable=attribute-error + await asyncio.gather(*write_coros) + await txn.commit_async() async def serialize( self, @@ -117,18 +109,34 @@ async def serialize( Sequence of commit futures which can be awaited to complete the save operation. """ - values = [p.value for p in params] - paraminfos = [ - _create_v0_saving_paraminfo(p, self._context, serialization_context) - for p in params - ] - - # `args` is not used by StringHandler.serialize, so it's not passed in. - commit_futures = await self._handler_impl.serialize(values, paraminfos) - if not commit_futures: - raise ValueError("No commit futures returned by StringHandler.serialize.") + return self._background_serialize(params, serialization_context) - return _async_futures(commit_futures) + async def _background_deserialize( + self, + params: Sequence[StringDeserializationParam], + deserialization_context: types.DeserializationContext, + ) -> Sequence[str]: + """Deserializes strings using Tensorstore in the background thread.""" + + async def _convert_to_string(tensorstore): + result = await tensorstore.read() + return str(result) + + open_futures = [] + for param in params: + tspec = self._get_json_tspec( + param.name, deserialization_context.parent_dir + ) + open_future = ts.open( + tspec, + open=True, + read=True, + context=deserialization_context.ts_context, + ) + open_futures += [open_future] + tensorstores = await asyncio.gather(*open_futures) + read_ops = [_convert_to_string(t) for t in tensorstores] + return await asyncio.gather(*read_ops) async def deserialize( self, @@ -145,20 +153,7 @@ async def deserialize( Returns: The deserialized sequence of scalar values as leaves. """ - - # validate all parameters - paraminfos = [ - _create_v0_restore_paraminfo(p, self._context, deserialization_context) - for p in params - ] - - async def _background_deserialize() -> Sequence[str]: - # This is needed because StringHandler.deserialize could return None - # values. However, it should be very rare. This is to make sure - # everything is string. - return [p or "" for p in await self._handler_impl.deserialize(paraminfos)] - - return asyncio.create_task(_background_deserialize()) + return self._background_deserialize(params, deserialization_context) async def metadata( self, @@ -175,13 +170,7 @@ async def metadata( Returns: Sequence of StringMetadata for each provided DeserializationParam. """ - paraminfos = [ - _create_v0_restore_paraminfo(p, self._context, deserialization_context) - for p in params - ] - - async def _get_metadata() -> Sequence[AbstractString]: - v0_metadatas = await self._handler_impl.metadata(paraminfos) - return ["string"] * len(v0_metadatas) + return ['string'] * len(params) - return await _get_metadata() + def typestr(self) -> str: + return 'str' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py index a9ecd099f..58d872a57 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py @@ -211,6 +211,10 @@ async def metadata( """ ... + def typestr(self) -> str: + """A string used to represent the leaf type.""" + ... + LeafHandlerRegistryItem = Tuple[ Type[Leaf], Type[AbstractLeaf], Type[LeafHandler[Leaf, AbstractLeaf]] diff --git a/docs/guides/checkpoint/v1/customization.ipynb b/docs/guides/checkpoint/v1/customization.ipynb index 5d496502d..b4533f934 100644 --- a/docs/guides/checkpoint/v1/customization.ipynb +++ b/docs/guides/checkpoint/v1/customization.ipynb @@ -1007,7 +1007,10 @@ " **{k: getattr(__builtins__, v) for k, v in contents.items()}\n", " )\n", " )\n", - " return ret" + " return ret\n", + " \n", + " def typestr(self) -> str:\n", + " return 'Point'" ], "outputs": [], "execution_count": 16