"""Source code for dagster._core.storage.upath_io_manager."""

from abc import abstractmethod
from typing import Any, Dict, Mapping, Optional, Union

from upath import UPath

from dagster import (
    InputContext,
    MetadataValue,
    MultiPartitionKey,
    OutputContext,
    _check as check,
)
from dagster._core.storage.memoizable_io_manager import MemoizableIOManager


class UPathIOManager(MemoizableIOManager):
    """Abstract IOManager base class compatible with local and cloud storage via
    `universal-pathlib` and `fsspec`.

    Features:
     - handles partitioned assets
     - handles loading a single upstream partition
     - handles loading multiple upstream partitions
       (with respect to <PyObject object="PartitionMapping" />)
     - the `get_metadata` method can be customized to add additional metadata to the output
     - the `allow_missing_partitions` metadata value can be set to `True` to skip missing
       partitions (the default behavior is to raise an error)
    """

    # File extension appended to every produced path (e.g. ".parquet").
    # Override in child class; empty string means "no extension".
    extension: str = ""

    def __init__(
        self,
        base_path: UPath,
    ):
        # A valid extension is either empty or contains a dot (".csv", ".tar.gz", ...).
        assert self.extension == "" or "." in self.extension
        self._base_path = base_path

    @abstractmethod
    def dump_to_path(self, context: OutputContext, obj: Any, path: UPath):
        """Child classes should override this method to write the object to the filesystem."""

    @abstractmethod
    def load_from_path(self, context: InputContext, path: UPath) -> Any:
        """Child classes should override this method to load the object from the filesystem."""

    def get_metadata(
        self,
        context: OutputContext,
        obj: Any,
    ) -> Dict[str, MetadataValue]:
        """Child classes should override this method to add custom metadata to the outputs."""
        return {}

    def has_output(self, context: OutputContext) -> bool:
        # Memoization hook: the output is considered present iff its file exists.
        return self._get_path(context).exists()

    def _with_extension(self, path: UPath) -> UPath:
        # Can't just call path.with_suffix(self.extension) because if
        # self.extension is "" then this trims off any extension that
        # was in the path previously.
        return path.parent / (path.name + self.extension)

    def _get_path_without_extension(self, context: Union[InputContext, OutputContext]) -> UPath:
        if context.has_asset_key:
            # we are dealing with an asset

            # we are not using context.get_asset_identifier() because it already includes
            # the partition_key
            context_path = list(context.asset_key.path)
        else:
            # we are dealing with an op output
            context_path = list(context.get_identifier())

        return self._base_path.joinpath(*context_path)

    def _get_path(self, context: Union[InputContext, OutputContext]) -> UPath:
        """Returns the I/O path for a given context.

        Should not be used with partitions (use `_get_paths_for_partitions` instead).
        """
        path = self._get_path_without_extension(context)
        return self._with_extension(path)

    def _get_paths_for_partitions(
        self, context: Union[InputContext, OutputContext]
    ) -> Dict[str, UPath]:
        """Returns a dict of partition_keys into I/O paths for a given context."""
        if not context.has_asset_partitions:
            raise TypeError(
                f"Detected {context.dagster_type.typing_type} input type "
                "but the asset is not partitioned"
            )

        def _formatted_multipartitioned_path(partition_key: MultiPartitionKey) -> str:
            # Sort dimensions alphabetically by name so the on-disk layout is deterministic;
            # each dimension key becomes one nested directory level.
            ordered_dimension_keys = [
                key[1]
                for key in sorted(partition_key.keys_by_dimension.items(), key=lambda x: x[0])
            ]
            return "/".join(ordered_dimension_keys)

        formatted_partition_keys = [
            _formatted_multipartitioned_path(pk) if isinstance(pk, MultiPartitionKey) else pk
            for pk in context.asset_partition_keys
        ]

        asset_path = self._get_path_without_extension(context)
        return {
            partition_key: self._with_extension(asset_path / partition_key)
            for partition_key in formatted_partition_keys
        }

    def _get_multipartition_backcompat_paths(
        self, context: Union[InputContext, OutputContext]
    ) -> Mapping[str, UPath]:
        # Legacy layout for multi-partitioned assets: the raw MultiPartitionKey string was
        # used as a single path component instead of one directory per dimension.
        if not context.has_asset_partitions:
            raise TypeError(
                f"Detected {context.dagster_type.typing_type} input type "
                "but the asset is not partitioned"
            )

        partition_keys = context.asset_partition_keys

        asset_path = self._get_path_without_extension(context)
        return {
            partition_key: self._with_extension(asset_path / partition_key)
            for partition_key in partition_keys
            if isinstance(partition_key, MultiPartitionKey)
        }

    def _load_single_input(
        self, path: UPath, context: InputContext, backcompat_path: Optional[UPath] = None
    ) -> Any:
        """Load one object, falling back to `backcompat_path` if the canonical path is missing."""
        context.log.debug(f"Loading file from: {path}")
        try:
            obj = self.load_from_path(context=context, path=path)
        except FileNotFoundError as e:
            if backcompat_path is not None:
                try:
                    obj = self.load_from_path(context=context, path=backcompat_path)
                    context.log.debug(
                        f"File not found at {path}. Loaded instead from backcompat path:"
                        f" {backcompat_path}"
                    )
                except FileNotFoundError:
                    # Surface the error for the canonical path, not the backcompat one.
                    raise e
            else:
                raise e

        context.add_input_metadata({"path": MetadataValue.path(str(path))})
        return obj

    def _load_multiple_inputs(self, context: InputContext) -> Dict[str, Any]:
        # load multiple partitions
        allow_missing_partitions = (
            context.metadata.get("allow_missing_partitions", False)
            if context.metadata is not None
            else False
        )

        objs: Dict[str, Any] = {}
        paths = self._get_paths_for_partitions(context)
        backcompat_paths = self._get_multipartition_backcompat_paths(context)

        context.log.debug(f"Loading {len(paths)} partitions...")

        for partition_key, path in paths.items():
            context.log.debug(f"Loading partition from {path} using {self.__class__.__name__}")
            try:
                objs[partition_key] = self.load_from_path(context=context, path=path)
            except FileNotFoundError as e:
                backcompat_path = backcompat_paths.get(partition_key)
                if backcompat_path is not None:
                    try:
                        objs[partition_key] = self.load_from_path(
                            context=context, path=backcompat_path
                        )
                    except FileNotFoundError:
                        # Backcompat path is missing too; fall through to the
                        # missing-partition handling below.
                        pass
                if objs.get(partition_key) is None:
                    if not allow_missing_partitions:
                        # Re-raise the error for the canonical (non-backcompat) path.
                        raise e
                    context.log.debug(
                        f"Couldn't load partition {path} and skipped it "
                        "because the input metadata includes allow_missing_partitions=True"
                    )

        # TODO: context.add_output_metadata fails in the partitioned context. this should be fixed?
        return objs

    def load_input(self, context: InputContext) -> Union[Any, Dict[str, Any]]:
        if not context.has_asset_key:
            # we are dealing with an op output which is always non-partitioned
            path = self._get_path(context)
            return self._load_single_input(path, context)
        else:
            if not context.has_asset_partitions:
                # we are dealing with a non-partitioned asset
                path = self._get_path(context)
                return self._load_single_input(path, context)
            else:
                asset_partition_keys = context.asset_partition_keys
                if len(asset_partition_keys) == 0:
                    return None
                elif len(asset_partition_keys) == 1:
                    paths = self._get_paths_for_partitions(context)
                    check.invariant(len(paths) == 1, f"Expected 1 path, but got {len(paths)}")
                    path = list(paths.values())[0]
                    backcompat_paths = self._get_multipartition_backcompat_paths(context)
                    backcompat_path = (
                        None if not backcompat_paths else list(backcompat_paths.values())[0]
                    )
                    return self._load_single_input(path, context, backcompat_path)
                else:
                    # we are dealing with multiple partitions of an asset
                    type_annotation = context.dagster_type.typing_type
                    if type_annotation != Any and not is_dict_type(type_annotation):
                        check.failed(
                            "Loading an input that corresponds to multiple partitions, but the"
                            " type annotation on the op input is not a dict, Dict, Mapping, or"
                            f" Any: is '{type_annotation}'."
                        )
                    return self._load_multiple_inputs(context)

    def handle_output(self, context: OutputContext, obj: Any):
        if context.dagster_type.typing_type == type(None):
            check.invariant(
                obj is None,
                (
                    "Output had Nothing type or 'None' annotation, but handle_output received"
                    f" value that was not None and was of type {type(obj)}."
                ),
            )
            return None

        if context.has_asset_partitions:
            paths = self._get_paths_for_partitions(context)
            # Only a single partition can be materialized per run.
            # check.invariant (not a bare assert) so the check survives `python -O`,
            # matching the style used in load_input.
            check.invariant(len(paths) == 1, f"Expected 1 path, but got {len(paths)}")
            path = list(paths.values())[0]
        else:
            path = self._get_path(context)

        path.parent.mkdir(parents=True, exist_ok=True)
        context.log.debug(f"Writing file at: {path}")

        self.dump_to_path(context=context, obj=obj, path=path)

        metadata = {"path": MetadataValue.path(str(path))}
        custom_metadata = self.get_metadata(context=context, obj=obj)
        metadata.update(custom_metadata)  # type: ignore

        context.add_output_metadata(metadata)
def is_dict_type(type_obj) -> bool:
    """Return True if `type_obj` annotates a dict-like mapping.

    Accepts the bare `dict` type and any subscripted generic alias whose origin is a
    mapping: `dict[str, int]`, `typing.Dict[...]`, `typing.Mapping[...]`,
    `collections.abc.Mapping[...]`, etc.

    Note: the previous `type_obj.__origin__ in (dict, Dict, Mapping)` check silently
    rejected `Mapping[...]` annotations, because `Mapping[...].__origin__` is
    `collections.abc.Mapping`, which compares unequal to `typing.Mapping`.
    `typing.get_origin` + `issubclass` handles all spellings uniformly.
    """
    import collections.abc
    from typing import get_origin

    if type_obj == dict:
        return True
    origin = get_origin(type_obj)
    # get_origin returns None for non-generic annotations (int, Any, bare Mapping is
    # normalized to its abc origin); issubclass(dict, Mapping) covers Dict/dict aliases.
    return isinstance(origin, type) and issubclass(origin, collections.abc.Mapping)