Skip to content

Commit 29e9307

Browse files
Add "AssetsClient" (#38)
1 parent d402c2c commit 29e9307

17 files changed

Lines changed: 1168 additions & 5 deletions

File tree

armis_sdk/clients/assets_client.py

Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
import datetime
2+
from typing import AsyncIterator
3+
from typing import Literal
4+
from typing import Optional
5+
from typing import Type
6+
from typing import Union
7+
8+
import universalasync
9+
10+
from armis_sdk.core import response_utils
11+
from armis_sdk.core.armis_error import ArmisError
12+
from armis_sdk.core.armis_error import BulkUpdateError
13+
from armis_sdk.core.armis_error import BulkUpdateItemError
14+
from armis_sdk.core.base_entity_client import BaseEntityClient
15+
from armis_sdk.entities.asset import Asset
16+
from armis_sdk.entities.asset import AssetT
17+
from armis_sdk.entities.device import Device
18+
19+
AssetIdSource = Literal[
20+
"ASSET_ID",
21+
"IPV4_ADDRESS",
22+
"IPV6_ADDRESS",
23+
"MAC_ADDRESS",
24+
"SERIAL_NUMBER",
25+
]
26+
27+
28+
@universalasync.wrap
29+
class AssetsClient(BaseEntityClient): # pylint: disable=too-few-public-methods
30+
# pylint: disable=line-too-long
31+
"""
32+
A client for interacting with assets.
33+
34+
The primary entities for this client inherit from [Asset][armis_sdk.entities.asset.Asset]:
35+
36+
1. [Device][armis_sdk.entities.device.Device]
37+
"""
38+
39+
async def list_by_asset_id(
40+
self,
41+
asset_class: Type[AssetT],
42+
asset_ids: Union[list[int], list[str]],
43+
asset_id_source: AssetIdSource = "ASSET_ID",
44+
fields: Optional[list[str]] = None,
45+
) -> AsyncIterator[AssetT]:
46+
"""List assets by asset ID or other identifiers.
47+
48+
Args:
49+
asset_class: The asset class to list. Must inherit from [Asset][armis_sdk.entities.asset.Asset].
50+
asset_ids: A list of asset identifiers (int or str depending on asset_id_source).
51+
asset_id_source: The type of identifier provided in asset_ids.
52+
fields: Optional list of fields to retrieve. If None, all non-custom fields are retrieved.
53+
54+
Yields:
55+
Assets of the specified class matching the provided identifiers.
56+
57+
Example:
58+
```python linenums="1" hl_lines="13 17"
59+
import asyncio
60+
61+
from armis_sdk.clients.assets_client import AssetsClient
62+
from armis_sdk.entities.device import Device
63+
64+
async def main():
65+
assets_client = AssetsClient()
66+
67+
device_ids = [1, 2, 3]
68+
ipv4_addresses = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
69+
70+
# List by the default source "ASSET_ID"
71+
async for device in assets_client.list_by_asset_id(Device, device_ids):
72+
print(device)
73+
74+
# List by explicit source "IPV4_ADDRESS"
75+
async for device in assets_client.list_by_asset_id(Device, ipv4_addresses, asset_id_source="IPV4_ADDRESS"):
76+
print(device)
77+
78+
asyncio.run(main())
79+
```
80+
"""
81+
filter_ = {
82+
"filter_criteria": "ASSET_ID",
83+
"asset_ids": asset_ids,
84+
"asset_id_source": asset_id_source,
85+
}
86+
async for item in self._list_assets(asset_class, fields, filter_):
87+
yield item
88+
89+
async def list_by_last_seen(
90+
self,
91+
asset_class: Type[AssetT],
92+
last_seen: Union[datetime.datetime, datetime.timedelta],
93+
fields: Optional[list[str]] = None,
94+
) -> AsyncIterator[AssetT]:
95+
"""List assets by last seen timestamp.
96+
97+
Args:
98+
asset_class: The asset class to list. Must inherit from [Asset][armis_sdk.entities.asset.Asset].
99+
last_seen: Either a datetime (assets seen on or after this time) or timedelta (assets seen within this duration).
100+
fields: Optional list of fields to retrieve. If None, all non-custom fields are retrieved.
101+
102+
Yields:
103+
Assets of the specified class matching the last seen criteria.
104+
105+
Raises:
106+
ArmisError: If last_seen is neither datetime nor timedelta.
107+
108+
Example:
109+
```python linenums="1" hl_lines="11 15"
110+
import asyncio
111+
import datetime
112+
113+
from armis_sdk.clients.assets_client import AssetsClient
114+
from armis_sdk.entities.device import Device
115+
116+
async def main():
117+
assets_client = AssetsClient()
118+
119+
# List devices seen in the last 24 hours
120+
async for device in assets_client.list_by_last_seen(Device, datetime.timedelta(days=1)):
121+
print(device)
122+
123+
# List devices seen on or after December 8, 2025
124+
async for device in assets_client.list_by_last_seen(Device, datetime.datetime(2025, 12, 8)):
125+
print(device)
126+
127+
asyncio.run(main())
128+
```
129+
"""
130+
filter_: dict[str, Union[str, int]] = {"filter_criteria": "LAST_SEEN"}
131+
132+
if isinstance(last_seen, datetime.datetime):
133+
filter_["last_seen_ge"] = last_seen.isoformat()
134+
elif isinstance(last_seen, datetime.timedelta):
135+
filter_["last_seen_seconds"] = int(last_seen.total_seconds())
136+
else:
137+
raise ArmisError(f"Invalid 'last_seen' type {type(last_seen)}")
138+
139+
async for item in self._list_assets(asset_class, fields, filter_):
140+
yield item
141+
142+
async def update(
143+
self,
144+
assets: list[AssetT],
145+
fields: list[str],
146+
asset_id_source: AssetIdSource = "ASSET_ID",
147+
) -> None:
148+
# pylint: disable=line-too-long
149+
"""Bulk update assets.
150+
151+
Args:
152+
assets: A list of assets. Items must inherit from [Asset][armis_sdk.entities.asset.Asset].
153+
fields: A list of fields to update. Currently only custom properties are supported (i.e. `custom.MyField`).
154+
asset_id_source: From where on the asset to take the unique identifier.
155+
156+
Raises:
157+
BulkUpdateError: If an error occurs while trying to update any of the assets.
158+
159+
Example:
160+
```python linenums="1" hl_lines="13 16"
161+
import asyncio
162+
163+
from armis_sdk.clients.assets_client import AssetsClient
164+
from armis_sdk.entities.device import Device
165+
166+
167+
async def main():
168+
assets_client = AssetsClient()
169+
170+
device = Device(device_id=1, ipv4_addresses=["1.2.3.4"], custom={"MyField": "Hello, World"})
171+
172+
# Update based on the default source "ASSET_ID"
173+
await assets_client.update([device], ["custom.MyField"])
174+
175+
# Update based on the explicit source "IPV4_ADDRESS"
176+
await assets_client.update([device], ["custom.MyField"], asset_id_source="IPV4_ADDRESS")
177+
178+
asyncio.run(main())
179+
```
180+
"""
181+
if not assets or not fields:
182+
return
183+
184+
self._validate_asset_class(assets)
185+
186+
asset_class = type(assets[0])
187+
self._validate_fields(asset_class, fields, allow_model_members=False)
188+
189+
items = []
190+
for index, asset in enumerate(assets):
191+
asset_id = self._get_asset_id(asset, index, asset_id_source)
192+
for field in fields:
193+
items.append(self._create_bulk_update_request(asset, asset_id, field))
194+
195+
if not items:
196+
return
197+
198+
payload = {
199+
"items": items,
200+
"asset_type": asset_class.asset_type,
201+
"asset_id_source": asset_id_source,
202+
}
203+
async with self._armis_client.client() as client:
204+
response = await client.post("/v3/assets/_bulk", json=payload)
205+
data = response_utils.get_data_dict(response)
206+
errors = [
207+
BulkUpdateItemError(index=index, request=items[index], response=item)
208+
for index, item in enumerate(data["items"])
209+
if item["status"] != 202
210+
]
211+
if errors:
212+
raise BulkUpdateError(errors)
213+
214+
@classmethod
215+
def _create_bulk_update_request(
216+
cls,
217+
asset: Asset,
218+
asset_id: Union[str, int],
219+
field: str,
220+
):
221+
request = {"asset_id": asset_id, "key": field}
222+
if cls._is_custom_field(field):
223+
key = field.split(".", 1)[1]
224+
if value := asset.custom.get(key):
225+
request["operation"] = "SET"
226+
request["value"] = value
227+
else:
228+
request["operation"] = "UNSET"
229+
else:
230+
raise ArmisError(f"Updating the field {field!r} is currently not supported")
231+
232+
return request
233+
234+
@classmethod
235+
def _get_asset_id(
236+
cls,
237+
asset: Asset,
238+
index: int,
239+
asset_id_source: AssetIdSource,
240+
) -> Union[str, int]:
241+
if isinstance(asset, Device):
242+
return cls._get_device_asset_id(asset, index, asset_id_source)
243+
244+
raise ArmisError(f"Can't get {asset_id_source} of asset {asset!r}")
245+
246+
@classmethod
247+
def _get_device_asset_id(
248+
cls,
249+
device: Device,
250+
index: int,
251+
asset_id_source: AssetIdSource,
252+
):
253+
if asset_id_source == "ASSET_ID":
254+
if device.device_id is None:
255+
raise ArmisError(f"Device at index {index} doesn't have a device id")
256+
return device.device_id
257+
258+
if asset_id_source == "MAC_ADDRESS":
259+
if device.mac_addresses is None or len(device.mac_addresses) != 1:
260+
raise ArmisError(
261+
f"Device at index {index} doesn't have exactly one mac address"
262+
)
263+
return device.mac_addresses[0]
264+
265+
if asset_id_source == "IPV4_ADDRESS":
266+
if device.ipv4_addresses is None or len(device.ipv4_addresses) != 1:
267+
raise ArmisError(
268+
f"Device at index {index} doesn't have exactly one IPv4 address"
269+
)
270+
return device.ipv4_addresses[0]
271+
272+
if asset_id_source == "IPV6_ADDRESS":
273+
if device.ipv6_addresses is None or len(device.ipv6_addresses) != 1:
274+
raise ArmisError(
275+
f"Device at index {index} doesn't have exactly one IPv6 address"
276+
)
277+
return device.ipv6_addresses[0]
278+
279+
if asset_id_source == "SERIAL_NUMBER":
280+
if device.serial_numbers is None or len(device.serial_numbers) != 1:
281+
raise ArmisError(
282+
f"Device at index {index} doesn't have exactly one serial number"
283+
)
284+
return device.serial_numbers[0]
285+
286+
raise ArmisError(f"Can't get {asset_id_source!r} of device at index {index}")
287+
288+
@classmethod
289+
def _is_custom_field(cls, field: str) -> bool:
290+
return field.startswith("custom.")
291+
292+
async def _list_assets(
293+
self,
294+
asset_class: Type[AssetT],
295+
fields: Optional[list[str]],
296+
filter_: dict,
297+
) -> AsyncIterator[AssetT]:
298+
fields = fields or sorted(asset_class.all_fields())
299+
300+
self._validate_fields(asset_class, fields)
301+
302+
body = {
303+
"asset_type": asset_class.asset_type,
304+
"fields": fields,
305+
"filter": filter_,
306+
}
307+
async for item in self._armis_client.list("/v3/assets/_search", body=body):
308+
yield asset_class.from_search_result(item)
309+
310+
@classmethod
311+
def _validate_asset_class(cls, assets: list[AssetT]):
312+
asset_types = {type(asset) for asset in assets}
313+
if len(asset_types) > 1:
314+
asset_types_str = ", ".join(sorted(repr(at.__name__) for at in asset_types))
315+
raise ArmisError(
316+
"All assets must be of the same type, "
317+
f"got {len(asset_types)} types: {asset_types_str}"
318+
)
319+
320+
@classmethod
321+
def _validate_fields(
322+
cls,
323+
asset_class: Type[AssetT],
324+
fields: list[str],
325+
allow_model_members=True,
326+
):
327+
invalid_fields = []
328+
all_fields = asset_class.all_fields()
329+
for field in fields:
330+
if cls._is_custom_field(field):
331+
continue
332+
333+
if allow_model_members and field in all_fields:
334+
continue
335+
336+
invalid_fields.append(field)
337+
338+
if invalid_fields:
339+
fields_str = ", ".join(map(repr, invalid_fields))
340+
raise ArmisError(
341+
f"The following fields are not supported with this operation: {fields_str}"
342+
)

armis_sdk/core/armis_client.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,12 @@ def client(self, retries: Optional[int] = None, backoff: Optional[float] = None)
8282
trust_env=True,
8383
)
8484

85-
async def list(self, url: str) -> AsyncIterator[dict]:
85+
async def list(self, url: str, body: Optional[dict] = None) -> AsyncIterator[dict]:
8686
"""List all items from a paginated endpoint.
8787
8888
Args:
8989
url (str): The relative endpoint URL.
90+
body (dict): Payload to send as POST request.
9091
9192
Returns:
9293
An (async) iterator of `dict`s.
@@ -113,9 +114,12 @@ async def main():
113114
"""
114115
page_size = int(os.getenv(ARMIS_PAGE_SIZE, str(DEFAULT_PAGE_LENGTH)))
115116
async with self.client() as client:
116-
params = {"limit": page_size}
117+
params = {"limit": page_size, **(body or {})}
117118
while True:
118-
response = await client.get(url, params=params)
119+
if body:
120+
response = await client.post(url, json=params)
121+
else:
122+
response = await client.get(url, params=params)
119123
data = response_utils.get_data_dict(response)
120124
items = data["items"]
121125
for item in items:

0 commit comments

Comments
 (0)