Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions gokart/gcs_obj_metadata_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import functools
import json
import re
from collections.abc import Iterable
from logging import getLogger
from typing import Any, Final
from typing import Any, Final, cast
from urllib.parse import urlsplit

from gokart.gcs_config import GCSConfig
Expand Down Expand Up @@ -125,10 +124,10 @@ def _get_patched_obj_metadata(
@staticmethod
def _get_serialized_string(required_task_outputs: FlattenableItems[RequiredTaskOutput]) -> FlattenableItems[str]:
if isinstance(required_task_outputs, RequiredTaskOutput):
return required_task_outputs.serialize()
return cast(FlattenableItems[str], required_task_outputs.serialize())
elif isinstance(required_task_outputs, dict):
return {k: GCSObjectMetadataClient._get_serialized_string(v) for k, v in required_task_outputs.items()}
elif isinstance(required_task_outputs, Iterable):
elif isinstance(required_task_outputs, list | tuple):
return [GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs]
else:
raise TypeError(
Expand Down
15 changes: 11 additions & 4 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def output(self) -> FlattenableItems[TargetOnKart]:

def requires(self) -> FlattenableItems[TaskOnKart[Any]]:
tasks = self.make_task_instance_dictionary()
return tasks or [] # when tasks is empty dict, then this returns empty list.
if tasks:
return cast(FlattenableItems[TaskOnKart[Any]], tasks)
return [] # when tasks is empty dict, then this returns empty list.

def make_task_instance_dictionary(self) -> dict[str, TaskOnKart[Any]]:
return {key: var for key, var in vars(self).items() if self.is_task_on_kart(var)}
Expand Down Expand Up @@ -354,9 +356,14 @@ def dump(self, obj: Any, target: None | str | TargetOnKart = None, custom_labels
if isinstance(obj, pd.DataFrame) and obj.empty:
raise EmptyDumpError()

required_task_outputs = map_flattenable_items(
lambda task: map_flattenable_items(lambda output: RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output()),
self.requires(),
required_task_outputs = cast(
FlattenableItems[RequiredTaskOutput],
map_flattenable_items(
lambda task: map_flattenable_items(
lambda output: RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output()
),
self.requires(),
),
)

self._get_output_target(target).dump(
Expand Down
4 changes: 2 additions & 2 deletions gokart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def add_config(file_path: str) -> None:


T = TypeVar('T')
FlattenableItems: TypeAlias = T | Iterable['FlattenableItems[T]'] | dict[str, 'FlattenableItems[T]']
FlattenableItems: TypeAlias = T | list['FlattenableItems[T]'] | tuple['FlattenableItems[T]', ...] | dict[str, 'FlattenableItems[T]']
Comment thread
hirosassa marked this conversation as resolved.


def flatten(targets: FlattenableItems[T]) -> list[T]:
Expand Down Expand Up @@ -76,7 +76,7 @@ def map_flattenable_items(func: Callable[[T], K], items: FlattenableItems[T]) ->
return tuple(map_flattenable_items(func, i) for i in items)
if isinstance(items, str):
return func(items) # type: ignore
if isinstance(items, Iterable):
if isinstance(items, list):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝
tuple is covered by L75

return list(map(lambda item: map_flattenable_items(func, item), items))
return func(items)

Expand Down