Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- #v1 Add a fresh implementation of `StringLeafHandler` to reduce unnecessary
  delegation complexity. Also introduce a `typestr` method on `LeafHandler`.

## [0.11.31] - 2025-12-11

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,6 @@ async def _convert_to_array_metadata() -> Sequence[ArrayMetadata]:
return ret

return await _convert_to_array_metadata()

def typestr(self) -> str:
  """Returns the typestr identifying jax.Array leaves in checkpoint metadata."""
  return 'jax.Array'
Original file line number Diff line number Diff line change
Expand Up @@ -428,19 +428,6 @@ def get_v0_type_handler_registry(
type handler registry.
context: The Context to be used to default construct the LeafHandlers.
"""

def _get_typestr(leaf_type: Any) -> str:
if leaf_type == jax.Array:
return type_handlers_v0.JAX_ARRAY_TYPE_STR
elif leaf_type == np.ndarray:
return 'np.ndarray'
elif leaf_type in (int, float, bytes, np.number):
return 'scalar'
elif leaf_type == str:
return 'string'
else:
return f'{leaf_type!r}'

# register standard v1 leaf handlers to the v0 type handler registry.
handlers = []
for leaf_type, _, leaf_handler_type in leaf_handler_registry.get_all():
Expand All @@ -455,7 +442,7 @@ def _get_typestr(leaf_type: Any) -> str:
leaf_type,
CompatibleTypeHandler(
leaf_handler,
typestr=_get_typestr(leaf_type),
typestr=leaf_handler.typestr(),
),
))
return type_handler_registry.create_type_handler_registry(*handlers)
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,6 @@ async def _convert_to_numpy_metadata() -> Sequence[NumpyMetadata]:
return ret

return await _convert_to_numpy_metadata()

def typestr(self) -> str:
  """Returns the typestr identifying np.ndarray leaves in checkpoint metadata."""
  return 'np.ndarray'
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,6 @@ def _get_type(meta: type_handlers_v0.ScalarMetadata):
return ret

return await _convert_to_scalar_metadata()

def typestr(self) -> str:
  """Returns the typestr identifying scalar leaves in checkpoint metadata."""
  # Single quotes for consistency with the sibling handlers; the string
  # value itself is unchanged.
  return 'scalar'
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@
"""

import asyncio
from typing import Awaitable, Sequence, Type
from typing import Any, Awaitable, Sequence

from absl import logging
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import types
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
import tensorstore as ts


AbstractString = types.AbstractString
StringSerializationParam = types.SerializationParam[str]
Expand All @@ -34,54 +37,6 @@
]


def _create_v0_saving_paraminfo(
param: StringSerializationParam,
context: context_lib.Context,
serialization_context: types.SerializationContext,
) -> type_handlers_v0.ParamInfo:
"""Creates a V0 ParamInfo from V1 params and contexts for saving."""

saving_options = context.array_options.saving

return type_handlers_v0.ParamInfo(
name=param.name,
parent_dir=serialization_context.parent_dir.path,
byte_limiter=serialization_context.byte_limiter,
is_ocdbt_checkpoint=saving_options.use_ocdbt,
use_zarr3=saving_options.use_zarr3,
ocdbt_target_data_file_size=saving_options.ocdbt_target_data_file_size,
ts_context=serialization_context.ts_context,
value_typestr="string",
)


def _create_v0_restore_paraminfo(
param: types.DeserializationParam[
AbstractString | Type[AbstractString] | None
],
context: context_lib.Context,
deserialization_context: types.DeserializationContext,
) -> type_handlers_v0.ParamInfo:
"""Creates a V0 ParamInfo from V1 params and contexts for loading."""

loading_options = context.array_options.loading

return type_handlers_v0.ParamInfo(
name=param.name,
parent_dir=deserialization_context.parent_dir,
skip_deserialize=False,
byte_limiter=deserialization_context.byte_limiter,
is_ocdbt_checkpoint=deserialization_context.ocdbt_checkpoint,
ts_context=deserialization_context.ts_context,
raise_array_data_missing_error=loading_options.raise_array_data_missing_error,
use_zarr3=deserialization_context.zarr3_checkpoint,
)


async def _async_futures(commit_futures: Sequence[future.Future]):
await asyncio.gather(*[asyncio.to_thread(f.result) for f in commit_futures])


class StringLeafHandler(types.LeafHandler[str, AbstractString]):
""":py:class:`.StringLeafHandler` that implements the :py:class:`~.v1.serialization.LeafHandler` Protocol."""

Expand All @@ -98,9 +53,46 @@ def __init__(
context: Context that will be used for this leaf handler.
"""
self._context = context_lib.get_context(context)
self._handler_impl = type_handlers_v0.StringHandler()
self._filename = '_strings.json'
logging.vlog(1, 'StringLeafHandler created.')

logging.vlog(1, "StringLeafHandler created.")
def _get_json_tspec(
    self,
    param_name: str,
    parent_dir: path_types.Path,
) -> dict[str, Any]:
  """Builds the JSON-driver Tensorstore spec for one string parameter.

  All strings share a single JSON file (self._filename) under parent_dir;
  each parameter is addressed inside it via a JSON pointer keyed by name.

  Args:
    param_name: Name of the parameter; becomes the JSON pointer key.
    parent_dir: Checkpoint directory containing the shared JSON file.

  Returns:
    A Tensorstore spec dict using the 'json' driver.
  """
  kvstore = ts_utils.build_kvstore_tspec(
      (parent_dir / self._filename).as_posix(), use_ocdbt=False
  )
  return {
      'driver': 'json',
      'kvstore': kvstore,
      'json_pointer': f'/{param_name}',
  }

async def _background_serialize(
    self,
    params: Sequence[StringSerializationParam],
    serialization_context: types.SerializationContext,
):
  """Writes strings using Tensorstore in the background thread.

  Only the primary host writes; all writes for this batch are grouped
  into a single Tensorstore transaction and committed once at the end.

  Args:
    params: Serialization params; each carries a name and a string value.
    serialization_context: Provides the (possibly not-yet-created) parent
      directory and the shared Tensorstore context.
  """
  # Block until the checkpoint directory has actually been created.
  parent_dir = await serialization_context.parent_dir.await_creation()
  write_coros = []
  txn = ts.Transaction()
  for param in params:
    # NOTE(review): the spec is built for every param even on non-primary
    # hosts, where it is unused — presumably harmless; confirm.
    tspec = self._get_json_tspec(param.name, parent_dir)
    if multihost.is_primary_host(
        self._context.multiprocessing_options.primary_host
    ):
      t = await ts.open(
          tspec,
          open=True,
          context=serialization_context.ts_context,
      )
      # Writes are staged in the transaction; nothing is durable until
      # commit_async below.
      write_coros.append(t.with_transaction(txn).write(param.value))  # pytype: disable=attribute-error
  await asyncio.gather(*write_coros)
  # On non-primary hosts this commits an empty transaction.
  await txn.commit_async()

async def serialize(
self,
Expand All @@ -117,18 +109,34 @@ async def serialize(
Sequence of commit futures which can be awaited to complete the save
operation.
"""
values = [p.value for p in params]
paraminfos = [
_create_v0_saving_paraminfo(p, self._context, serialization_context)
for p in params
]

# `args` is not used by StringHandler.serialize, so it's not passed in.
commit_futures = await self._handler_impl.serialize(values, paraminfos)
if not commit_futures:
raise ValueError("No commit futures returned by StringHandler.serialize.")
return self._background_serialize(params, serialization_context)

return _async_futures(commit_futures)
async def _background_deserialize(
    self,
    params: Sequence[StringDeserializationParam],
    deserialization_context: types.DeserializationContext,
) -> Sequence[str]:
  """Reads the stored strings back via Tensorstore in a background thread.

  Args:
    params: Deserialization params naming the strings to restore.
    deserialization_context: Provides the checkpoint directory and the
      shared Tensorstore context.

  Returns:
    The restored strings, in the same order as `params`.
  """

  async def _read_as_str(store) -> str:
    value = await store.read()
    return str(value)

  parent_dir = deserialization_context.parent_dir
  open_ops = [
      ts.open(
          self._get_json_tspec(p.name, parent_dir),
          open=True,
          read=True,
          context=deserialization_context.ts_context,
      )
      for p in params
  ]
  stores = await asyncio.gather(*open_ops)
  return await asyncio.gather(*(_read_as_str(s) for s in stores))

async def deserialize(
self,
Expand All @@ -145,20 +153,7 @@ async def deserialize(
Returns:
The deserialized sequence of scalar values as leaves.
"""

# validate all parameters
paraminfos = [
_create_v0_restore_paraminfo(p, self._context, deserialization_context)
for p in params
]

async def _background_deserialize() -> Sequence[str]:
# This is needed because StringHandler.deserialize could return None
# values. However, it should be very rare. This is to make sure
# everything is string.
return [p or "" for p in await self._handler_impl.deserialize(paraminfos)]

return asyncio.create_task(_background_deserialize())
return self._background_deserialize(params, deserialization_context)

async def metadata(
self,
Expand All @@ -175,13 +170,7 @@ async def metadata(
Returns:
Sequence of StringMetadata for each provided DeserializationParam.
"""
paraminfos = [
_create_v0_restore_paraminfo(p, self._context, deserialization_context)
for p in params
]

async def _get_metadata() -> Sequence[AbstractString]:
v0_metadatas = await self._handler_impl.metadata(paraminfos)
return ["string"] * len(v0_metadatas)
return ['string'] * len(params)

return await _get_metadata()
def typestr(self) -> str:
  """Returns the typestr identifying string leaves in checkpoint metadata.

  NOTE(review): v0 recorded string leaves under the typestr 'string'
  (this handler's metadata() also reports 'string'). Returning 'str'
  here would mismatch checkpoints written by v0, so 'string' is used.
  """
  return 'string'
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ async def metadata(
"""
...

def typestr(self) -> str:
  """A string used to represent the leaf type.

  Implementations return a short, stable identifier (e.g. a type name)
  that is recorded alongside each leaf in checkpoint metadata; it must
  stay constant across versions for checkpoints to remain readable.
  """
  ...


LeafHandlerRegistryItem = Tuple[
Type[Leaf], Type[AbstractLeaf], Type[LeafHandler[Leaf, AbstractLeaf]]
Expand Down
5 changes: 4 additions & 1 deletion docs/guides/checkpoint/v1/customization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,10 @@
" **{k: getattr(__builtins__, v) for k, v in contents.items()}\n",
" )\n",
" )\n",
" return ret"
" return ret\n",
" \n",
" def typestr(self) -> str:\n",
" return 'Point'"
],
"outputs": [],
"execution_count": 16
Expand Down
Loading