# Source code for zcollection.dask_utils
# Copyright (c) 2023 CNES
#
# All rights reserved. Use of this source code is governed by a
# BSD-style license that can be found in the LICENSE file.
"""
Dask utilities
==============
"""
from __future__ import annotations
from typing import Any, Callable, Iterator, Sequence
import itertools
import uuid
from dask.delayed import Delayed as dask_Delayed
import dask.distributed
import dask.highlevelgraph
[docs]
def dask_workers(client: dask.distributed.Client,
                 cores_only: bool = False) -> int:
    """Return the number of dask workers available.

    Args:
        client: dask client
        cores_only: if True, return the number of entries reported by
            ``client.ncores()`` (one per worker); otherwise return the total
            number of threads, summed over all workers reported by
            ``client.nthreads()``.

    Returns:
        number of dask workers (or of worker threads, see ``cores_only``)

    Raises:
        RuntimeError: If no dask workers are available.
    """
    result: int
    if cores_only:
        result = len(client.ncores())  # type: ignore[arg-type]
    else:
        result = sum(client.nthreads().values())  # type: ignore[arg-type]
    if result == 0:
        raise RuntimeError('No dask workers available')
    return result
[docs]
def get_client() -> dask.distributed.Client:
    """Return the dask client to use for computations.

    The client currently registered as default is returned when there is
    one. Otherwise a new client is created with ``processes=False`` and
    ``direct_to_workers=True``, and that new client is returned.

    Returns:
        default dask client
    """
    try:
        client = dask.distributed.get_client()
    except ValueError:
        # ``get_client`` signals "no client available" with ValueError: fall
        # back to creating a fresh client.
        client = dask.distributed.Client(
            processes=False,
            direct_to_workers=True,
        )
    return client
[docs]
def split_sequence(sequence: Sequence[Any],
                   sections: int | None = None) -> Iterator[Sequence[Any]]:
    """Split a sequence into sections.

    The sections are contiguous slices of ``sequence`` whose lengths differ
    by at most one, with the longer sections first. The number of sections
    is capped at ``len(sequence)``, so no empty section is produced.

    Args:
        sequence: The sequence to split.
        sections: The number of sections to split the sequence into. Default
            divides the sequence into n sections of one element.

    Returns:
        Iterator of sequences. An empty ``sequence`` yields nothing.

    Raises:
        ValueError: If ``sections`` is given and is not greater than zero.
            Because this function is a generator, the error is raised on the
            first iteration, not at call time.
    """
    length: int = len(sequence)
    if sections is None:
        sections = length
    elif sections <= 0:
        raise ValueError('The number of sections must be greater than zero.')
    # Nothing to split. Returning early also avoids the ZeroDivisionError
    # that ``divmod(0, 0)`` would raise below once sections is capped to 0.
    if length == 0:
        return
    sections = min(sections, length)
    size: int
    extras: int
    size, extras = divmod(length, sections)
    # Slice boundaries: the first ``extras`` sections hold one more element
    # than the remaining ones.
    bounds = tuple(
        itertools.accumulate([0] + extras * [size + 1] +
                             (sections - extras) * [size]))
    yield from (sequence[start:stop]
                for start, stop in zip(bounds, bounds[1:]))
[docs]
def simple_delayed(name: str, func: Callable) -> dask_Delayed:
    """Wrap a function in a single-task dask ``Delayed`` object.

    A unique graph key is built by appending a random UUID to ``name``. The
    returned object's graph has one layer, named after that key, which maps
    the key to ``func`` itself.

    NOTE(review): ``func`` is stored as the task value, not as a call;
    presumably callers invoke the returned Delayed to build call tasks —
    confirm against callers.

    Args:
        name: prefix of the key used in the task graph
        func: function to be delayed

    Returns:
        delayed object holding ``func``
    """
    key = f'{name}-{uuid.uuid4()}'
    layers = {key: {key: func}}
    dependencies: dict[str, set] = {key: set()}
    graph = dask.highlevelgraph.HighLevelGraph(layers, dependencies)
    return dask_Delayed(key, graph, None)