Skip to content

Commit 10ee277

Browse files
Added dispatch_apply and error to Parameters creation.
1 parent c1442cf commit 10ee277

File tree

3 files changed

+172
-11
lines changed

3 files changed

+172
-11
lines changed

src/synthesizrr/base/framework/trainer/RayTuneTrainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,6 +1541,7 @@ def get_final_metrics_stats(
15411541
continue
15421542
final_dataset_metrics[dataset_metric.display_name]: Dict[str, Union[int, float, Dict]] = {
15431543
'mean': np.mean(final_dataset_metrics[dataset_metric.display_name]),
1544+
'median': np.median(final_dataset_metrics[dataset_metric.display_name]),
15441545
'std': np.std(final_dataset_metrics[dataset_metric.display_name], ddof=1), ## Unbiased
15451546
'min': np.min(final_dataset_metrics[dataset_metric.display_name]),
15461547
'max': np.max(final_dataset_metrics[dataset_metric.display_name]),

src/synthesizrr/base/util/concurrency.py

Lines changed: 114 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
import ray
1616
from ray.exceptions import GetTimeoutError
1717
from ray.util.dask import RayDaskCallback
18-
from pydantic import validate_arguments, conint, confloat
19-
from synthesizrr.base.util.language import ProgressBar, set_param_from_alias, type_str, get_default, first_item, Parameters
18+
from pydantic import conint, confloat
19+
from synthesizrr.base.util.language import ProgressBar, set_param_from_alias, type_str, get_default, first_item, Parameters, \
20+
is_list_or_set_like, is_dict_like, PandasSeries, filter_kwargs
2021
from synthesizrr.base.constants.DataProcessingConstants import Parallelize, FailureAction, Status, COMPLETED_STATUSES
21-
2222
from functools import partial
2323
## Jupyter-compatible asyncio usage:
2424
import asyncio
@@ -445,6 +445,117 @@ def dispatch_executor(
445445
return None
446446

447447

448+
def dispatch_apply(
        struct: Union[List, Tuple, np.ndarray, PandasSeries, Set, frozenset, Dict],
        *args,
        fn: Callable,
        parallelize: Parallelize,
        forward_parallelize: bool = False,
        item_wait: Optional[float] = None,
        iter_wait: Optional[float] = None,
        iter: bool = False,
        **kwargs
) -> Any:
    """
    Applies `fn` to every element of `struct` via `dispatch(...)`, then collects the results using
    `accumulate(...)` (or `accumulate_iter(...)` when `iter=True`).

    An executor is created for this call with `dispatch_executor(...)` and is always passed to
    `stop_executor(...)` when this function exits (normally or with an error).

    :param struct: list/set-like or dict-like collection of items. For dict-like inputs, the futures
        are stored in a dict under the same keys; for list/set-like inputs, in a list.
    :param args: accepted but not used by this function.
    :param fn: callable invoked as `fn(item, **kwargs)` for each item.
    :param parallelize: parallelization strategy; converted using `Parallelize.from_str`.
    :param forward_parallelize: when True, `parallelize` is additionally passed to `fn` as a keyword
        argument (after going through `filter_kwargs`, like every other keyword forwarded to `fn`).
    :param item_wait: forwarded as `delay` to `dispatch` and as `item_wait` when collecting. When None,
        a default is chosen based on `parallelize`.
    :param iter_wait: forwarded as `iter_wait` when collecting. When None, a default is chosen based
        on `parallelize`.
    :param iter: when True, results are collected with `accumulate_iter` instead of `accumulate`.
    :param kwargs: forwarded to `dispatch_executor`, to `fn` (filtered by `filter_kwargs`), and to the
        accumulate call. `progress_bar` (aliases: `progress`, `pbar`; default True) is removed from
        `kwargs` and used to configure the "Submitting" and "Collecting" progress bars.
    :return: the value returned by `accumulate` / `accumulate_iter`.
    :raises NotImplementedError: if `struct` is neither list/set-like nor dict-like.
    """
    ## NOTE(review): `*args` is never forwarded to `fn` - confirm whether that is intentional.
    parallelize: Parallelize = Parallelize.from_str(parallelize)
    item_wait: float = get_default(
        item_wait,
        {
            Parallelize.ray: _RAY_ACCUMULATE_ITEM_WAIT,
            Parallelize.processes: _LOCAL_ACCUMULATE_ITEM_WAIT,
            Parallelize.threads: _LOCAL_ACCUMULATE_ITEM_WAIT,
            Parallelize.asyncio: 0.0,
            Parallelize.sync: 0.0,
        }[parallelize]
    )
    iter_wait: float = get_default(
        iter_wait,
        {
            Parallelize.ray: _RAY_ACCUMULATE_ITER_WAIT,
            Parallelize.processes: _LOCAL_ACCUMULATE_ITER_WAIT,
            Parallelize.threads: _LOCAL_ACCUMULATE_ITER_WAIT,
            Parallelize.asyncio: 0.0,
            Parallelize.sync: 0.0,
        }[parallelize]
    )
    ## `parallelize` must NOT be inserted into `kwargs`: the calls below already pass `parallelize=...`
    ## explicitly, so also unpacking it from `**kwargs` raises
    ## "TypeError: got multiple values for keyword argument 'parallelize'".
    ## Instead, it is bound here and injected directly into the call to `fn`.
    forwarded_to_fn: Dict[str, Any] = {}
    if forward_parallelize:
        forwarded_to_fn = filter_kwargs(fn, parallelize=parallelize)

    ## Defined once (it does not depend on the item), rather than once per loop iteration.
    def submit_task(item, **dispatch_kwargs):
        return fn(item, **forwarded_to_fn, **dispatch_kwargs)

    executor: Optional = dispatch_executor(
        parallelize=parallelize,
        **kwargs,
    )
    try:
        set_param_from_alias(kwargs, param='progress_bar', alias=['progress', 'pbar'], default=True)
        progress_bar: Union[ProgressBar, Dict, bool] = kwargs.pop('progress_bar', False)
        submit_pbar: ProgressBar = ProgressBar.of(
            progress_bar,
            total=len(struct),
            desc='Submitting',
            prefer_kwargs=False,
            unit='item',
        )
        collect_pbar: ProgressBar = ProgressBar.of(
            progress_bar,
            total=len(struct),
            desc='Collecting',
            prefer_kwargs=False,
            unit='item',
        )
        ## Same for every item, so computed once (after `progress_bar` has been popped from kwargs).
        fn_kwargs: Dict = filter_kwargs(fn, **kwargs)
        if is_list_or_set_like(struct):
            futs = []
            for v in struct:
                futs.append(
                    dispatch(
                        fn=submit_task,
                        item=v,
                        parallelize=parallelize,
                        executor=executor,
                        delay=item_wait,
                        **fn_kwargs,
                    )
                )
                submit_pbar.update(1)
        elif is_dict_like(struct):
            futs = {}
            for k, v in struct.items():
                ## NOTE(review): `key=k` is handed to `dispatch`; presumably it ends up as a keyword
                ## argument of `fn` - confirm against `dispatch`.
                futs[k] = dispatch(
                    fn=submit_task,
                    key=k,
                    item=v,
                    parallelize=parallelize,
                    executor=executor,
                    delay=item_wait,
                    **fn_kwargs,
                )
                submit_pbar.update(1)
        else:
            raise NotImplementedError(f'Unsupported type: {type_str(struct)}')
        submit_pbar.success()
        ## NOTE(review): the `finally` below runs as soon as we return. If `accumulate_iter` is lazy,
        ## the executor is stopped before the caller consumes the results - confirm this is safe.
        if iter:
            return accumulate_iter(
                futs,
                item_wait=item_wait,
                iter_wait=iter_wait,
                progress_bar=collect_pbar,
                **kwargs
            )
        else:
            return accumulate(
                futs,
                item_wait=item_wait,
                iter_wait=iter_wait,
                progress_bar=collect_pbar,
                **kwargs
            )
    finally:
        stop_executor(executor)
557+
558+
448559
def get_result(
449560
x,
450561
*,
@@ -785,7 +896,6 @@ def wait(
785896
wait_if_future(futures)
786897

787898

788-
@validate_arguments
789899
def retry(
790900
fn,
791901
*args,

src/synthesizrr/base/util/language.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -602,13 +602,19 @@ def _create_param(p: inspect.Parameter) -> inspect.Parameter:
602602
wrapper.__signature__ = sig
603603
wrapper.__annotations__ = {f"{n}_" if n in names_to_fix else n: v for n, v in f.__annotations__.items()}
604604

605-
return validate_arguments(
606-
wrapper,
607-
config={
608-
"allow_population_by_field_name": True,
609-
"arbitrary_types_allowed": True,
610-
}
611-
)
605+
try:
606+
return validate_arguments(
607+
wrapper,
608+
config={
609+
"allow_population_by_field_name": True,
610+
"arbitrary_types_allowed": True,
611+
}
612+
)
613+
except Exception as e:
614+
raise ValueError(
615+
f'Error creating model for function {get_fn_spec(f).resolved_name}.'
616+
f'\nEncountered Exception: {format_exception_msg(e)}'
617+
)
612618

613619

614620
def not_impl(
@@ -1531,6 +1537,27 @@ def invert_dict(d: Dict) -> Dict:
15311537
return d_inv
15321538

15331539

1540+
def iter_dict(d, depth: int = 1, *, _cur_depth: int = 0):
    """
    Iterate over a (possibly nested) dictionary, flattening up to `depth` levels of keys.

    Each yielded tuple holds the keys leading to a value, followed by the value itself. A nested
    dictionary is only descended into while fewer than `depth` key levels have been collected;
    beyond that it is yielded as a value. Non-dict values are always yielded as-is, so tuples
    can have different lengths.

    Example:
        >>> list(iter_dict({'a': {'x': 1}, 'b': 2}, depth=2))
        [('a', 'x', 1), ('b', 2)]

    :param d: The dictionary to iterate over.
    :param depth: Maximum number of key levels to unpack (1 = top-level keys only).
    :param _cur_depth: Internal recursion counter (number of levels already descended).
    :return: Yields tuples where the first elements are keys at successive depths, and the last
        element is the value.
    :raises AssertionError: if `d` is not a dict, or `depth` is not an integer >= 1. As this is a
        generator, the check runs when iteration starts, not when the function is called.
    """
    ## Raised explicitly (instead of via `assert`) so that validation is not stripped under
    ## `python -O`; the exception type and messages are kept unchanged.
    if not isinstance(d, dict):
        raise AssertionError(f'Input must be a dictionary, found: {type(d)}')
    if not (isinstance(depth, int) and depth >= 1):
        raise AssertionError('depth must be an integer (1 or more)')

    ## Whether we may go one level deeper is the same for every key at this level.
    can_descend: bool = _cur_depth < depth - 1
    for k, v in d.items():
        if can_descend and isinstance(v, dict):
            ## Nested dictionary within the depth budget: prefix each sub-result with this key.
            for subkeys in iter_dict(v, depth=depth, _cur_depth=_cur_depth + 1):
                yield (k,) + subkeys
        else:
            ## Leaf value, or a dictionary beyond the requested depth: yield the key-value pair.
            yield (k, v)
1559+
1560+
15341561
## ======================== NumPy utils ======================== ##
15351562
def is_numpy_integer_array(data: Any) -> bool:
15361563
if not isinstance(data, np.ndarray):
@@ -2625,6 +2652,15 @@ class Parameters(BaseModel, ABC):
26252652
aliases: ClassVar[Tuple[str, ...]] = tuple()
26262653
dict_exclude: ClassVar[Tuple[str, ...]] = tuple()
26272654

2655+
def __init__(self, *args, **kwargs):
    """
    Builds the instance via the parent constructor, re-raising any failure as a ValueError whose
    message names the concrete class being created.

    :raises ValueError: if the parent constructor raises any Exception; the original exception is
        attached as the cause.
    """
    try:
        super().__init__(*args, **kwargs)
    except Exception as e:
        ## `from e` chains the original error explicitly, so it stays available as `__cause__`.
        raise ValueError(
            f'Cannot create Pydantic instance of type "{self.class_name}".'
            f'\nEncountered exception: {format_exception_msg(e)}'
        ) from e
2663+
26282664
@classproperty
26292665
def class_name(cls) -> str:
26302666
return str(cls.__name__) ## Will return the child class name.
@@ -3227,6 +3263,15 @@ def create_progress_bar(
32273263
smoothing=smoothing,
32283264
**kwargs
32293265
)
3266+
elif style == 'ray':
3267+
from ray.experimental import tqdm_ray
3268+
kwargs = filter_keys(
3269+
kwargs,
3270+
keys=set(get_fn_spec(tqdm_ray.tqdm).args + get_fn_spec(tqdm_ray.tqdm).kwargs),
3271+
how='include',
3272+
)
3273+
from ray.experimental import tqdm_ray
3274+
return tqdm_ray.tqdm(**kwargs)
32303275
else:
32313276
return StdTqdmProgressBar(
32323277
ncols=ncols,
@@ -3311,6 +3356,11 @@ def ignore_all_output():
33113356
yield
33123357

33133358

3359+
@contextmanager
def ignore_nothing():
    """
    No-op context manager: runs the body of the `with` block unchanged and yields None.
    """
    yield None
3362+
3363+
33143364
# from pydantic import Field, AliasChoices
33153365
# def Alias(*, default: Optional[Any] = None, alias: Union[Tuple[str, ...], List[str], Set[str], str]):
33163366
# alias: AliasChoices = AliasChoices(*as_tuple(alias))

0 commit comments

Comments
 (0)