Source code for dagster._core.definitions.multi_dimensional_partitions

import itertools
from datetime import datetime
from functools import reduce
from typing import (
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    cast,
)

import dagster._check as check
from dagster._annotations import experimental
from dagster._core.errors import (
    DagsterInvalidDefinitionError,
    DagsterInvalidInvocationError,
    DagsterUnknownPartitionError,
)
from dagster._core.instance import DynamicPartitionsStore
from dagster._core.storage.tags import (
    MULTIDIMENSIONAL_PARTITION_PREFIX,
    get_multidimensional_partition_tag,
)

from .partition import (
    DefaultPartitionsSubset,
    DynamicPartitionsDefinition,
    Partition,
    PartitionsDefinition,
    PartitionsSubset,
    StaticPartitionsDefinition,
)
from .time_window_partitions import TimeWindowPartitionsDefinition

# Characters that _check_valid_partitions_dimensions rejects inside the keys of a
# static partitions dimension.
INVALID_STATIC_PARTITIONS_KEY_CHARACTERS = {"|", ",", "[", "]"}

# Separator placed between the per-dimension keys in a MultiPartitionKey string.
MULTIPARTITION_KEY_DELIMITER = "|"


class PartitionDimensionKey(
    NamedTuple("_PartitionDimensionKey", [("dimension_name", str), ("partition_key", str)])
):
    """The key of one dimension within a multi-dimensional partition key.

    Attributes:
        dimension_name (str): Name of the dimension.
        partition_key (str): Partition key chosen for that dimension.
    """

    def __new__(cls, dimension_name: str, partition_key: str):
        # Both fields must be strings; validate them in declaration order.
        checked_name = check.str_param(dimension_name, "dimension_name")
        checked_key = check.str_param(partition_key, "partition_key")
        return super().__new__(cls, dimension_name=checked_name, partition_key=checked_key)


class MultiPartitionKey(str):
    """A multi-dimensional partition key stores the partition key for each dimension.

    Subclasses the string class to keep partition key type as a string. Contains additional
    methods to access the partition key for each dimension.
    Creates a string representation of the partition key for each dimension, separated by a pipe
    (|). Orders the dimensions by name, to ensure consistent string representation.

    Args:
        keys_by_dimension (Mapping[str, str]): Mapping of dimension name to the partition key
            of that dimension.
    """

    # Class-level default; __new__ assigns a fresh list to every instance it creates.
    dimension_keys: List[PartitionDimensionKey] = []

    def __new__(cls, keys_by_dimension: Mapping[str, str]):
        # Validate under the real parameter name so that a failure message points the
        # caller at the argument they actually passed.
        check.mapping_param(keys_by_dimension, "keys_by_dimension", key_type=str, value_type=str)

        # Sorting by dimension name makes the string form independent of mapping order.
        dimension_keys: List[PartitionDimensionKey] = [
            PartitionDimensionKey(dimension, keys_by_dimension[dimension])
            for dimension in sorted(keys_by_dimension)
        ]

        str_key = super(MultiPartitionKey, cls).__new__(
            cls,
            MULTIPARTITION_KEY_DELIMITER.join(
                [dim_key.partition_key for dim_key in dimension_keys]
            ),
        )

        str_key.dimension_keys = dimension_keys

        return str_key

    def __getnewargs__(self):
        # When this instance is pickled, replace the argument to __new__ with the
        # dimension key mapping instead of the string representation.
        return ({dim_key.dimension_name: dim_key.partition_key for dim_key in self.dimension_keys},)

    @property
    def keys_by_dimension(self) -> Mapping[str, str]:
        """Mapping of dimension name to partition key, rebuilt from ``dimension_keys``."""
        return {dim_key.dimension_name: dim_key.partition_key for dim_key in self.dimension_keys}
class PartitionDimensionDefinition(
    NamedTuple(
        "_PartitionDimensionDefinition",
        [
            ("name", str),
            ("partitions_def", PartitionsDefinition),
        ],
    )
):
    """Pairs a dimension name with the partitions definition of that dimension."""

    def __new__(
        cls,
        name: str,
        partitions_def: PartitionsDefinition,
    ):
        # Validate the fields in declaration order before building the tuple.
        checked_name = check.str_param(name, "name")
        checked_def = check.inst_param(partitions_def, "partitions_def", PartitionsDefinition)
        return super().__new__(cls, name=checked_name, partitions_def=checked_def)

    def __eq__(self, other: object) -> bool:
        # Only another PartitionDimensionDefinition can compare equal.
        if not isinstance(other, PartitionDimensionDefinition):
            return False
        return self.name == other.name and self.partitions_def == other.partitions_def


# Partitions definition types accepted as a dimension by
# _check_valid_partitions_dimensions.
ALLOWED_PARTITION_DIMENSION_TYPES = (
    StaticPartitionsDefinition,
    TimeWindowPartitionsDefinition,
    DynamicPartitionsDefinition,
)


def _check_valid_partitions_dimensions(
    partitions_dimensions: Mapping[str, PartitionsDefinition]
) -> None:
    """Validate every dimension of a prospective multi-partitions definition.

    Args:
        partitions_dimensions: Mapping of dimension name to partitions definition.

    Raises:
        DagsterInvalidDefinitionError: If a definition is not one of
            ALLOWED_PARTITION_DIMENSION_TYPES, if a DynamicPartitionsDefinition has no
            name, or if a key of a StaticPartitionsDefinition contains one of
            INVALID_STATIC_PARTITIONS_KEY_CHARACTERS.
    """
    for dim_name, partitions_def in partitions_dimensions.items():
        if not isinstance(partitions_def, ALLOWED_PARTITION_DIMENSION_TYPES):
            raise DagsterInvalidDefinitionError(
                f"Invalid partitions definition type {type(partitions_def)}. "
                "Only the following partitions definition types are supported: "
                f"{ALLOWED_PARTITION_DIMENSION_TYPES}."
            )

        unnamed_dynamic = (
            isinstance(partitions_def, DynamicPartitionsDefinition)
            and partitions_def.name is None
        )
        if unnamed_dynamic:
            raise DagsterInvalidDefinitionError(
                "DynamicPartitionsDefinition must have a name to be used in a"
                " MultiPartitionsDefinition."
            )

        # The key character restriction is only checked for static dimensions.
        if not isinstance(partitions_def, StaticPartitionsDefinition):
            continue

        for key in partitions_def.get_partition_keys():
            if INVALID_STATIC_PARTITIONS_KEY_CHARACTERS & set(key):
                raise DagsterInvalidDefinitionError(
                    f"Invalid character in partition key for dimension {dim_name}. "
                    "A multi-partitions definition cannot contain partition keys with "
                    "the following characters: |, [, ], ,"
                )
@experimental
class MultiPartitionsDefinition(PartitionsDefinition):
    """Takes the cross-product of partitions from two partitions definitions.

    For example, with a static partitions definition where the partitions are ["a", "b", "c"]
    and a daily partitions definition, this partitions definition will have the following
    partitions:

    2020-01-01|a
    2020-01-01|b
    2020-01-01|c
    2020-01-02|a
    2020-01-02|b
    ...

    Args:
        partitions_defs (Mapping[str, PartitionsDefinition]):
            A mapping of dimension name to partitions definition. The total set of partitions will
            be the cross-product of the partitions from each PartitionsDefinition.

    Attributes:
        partitions_defs (Sequence[PartitionDimensionDefinition]):
            A sequence of PartitionDimensionDefinition objects, each of which contains a dimension
            name and a PartitionsDefinition. The total set of partitions will be the cross-product
            of the partitions from each PartitionsDefinition. This sequence is ordered by
            dimension name, to ensure consistent ordering of the partitions.
    """

    def __init__(self, partitions_defs: Mapping[str, PartitionsDefinition]):
        if not len(partitions_defs.keys()) == 2:
            raise DagsterInvalidInvocationError(
                "Dagster currently only supports multi-partitions definitions with 2 partitions"
                " definitions. Your multi-partitions definition has"
                f" {len(partitions_defs.keys())} partitions definitions."
            )
        check.mapping_param(
            partitions_defs, "partitions_defs", key_type=str, value_type=PartitionsDefinition
        )

        _check_valid_partitions_dimensions(partitions_defs)

        # Dimensions are kept sorted by name so that key strings have a stable order.
        self._partitions_defs: List[PartitionDimensionDefinition] = sorted(
            [
                PartitionDimensionDefinition(name, partitions_def)
                for name, partitions_def in partitions_defs.items()
            ],
            key=lambda x: x.name,
        )

    @property
    def partitions_subset_class(self) -> Type["PartitionsSubset"]:
        """The PartitionsSubset implementation used with this definition."""
        return MultiPartitionsSubset

    @property
    def partition_dimension_names(self) -> List[str]:
        """Dimension names, in the same (name-sorted) order as ``partitions_defs``."""
        return [dim_def.name for dim_def in self._partitions_defs]

    @property
    def partitions_defs(self) -> Sequence[PartitionDimensionDefinition]:
        """The dimension definitions, sorted by dimension name."""
        return self._partitions_defs

    def get_partition(
        self,
        partition_key: str,
        current_time: Optional[datetime] = None,
        dynamic_partitions_store: Optional[DynamicPartitionsStore] = None,
    ) -> Partition:
        """Return the Partition for ``partition_key``.

        A plain string is first parsed with ``get_partition_key_from_str``. The returned
        Partition's value maps each dimension name to that dimension's Partition.

        Raises:
            DagsterUnknownPartitionError: If the key's dimension names differ from the
                dimension names of this definition.
        """
        if not isinstance(partition_key, MultiPartitionKey):
            partition_key = self.get_partition_key_from_str(partition_key)
        partition_key = cast(MultiPartitionKey, partition_key)

        # The property builds a new dict on each access, so read it once.
        keys_by_dimension = partition_key.keys_by_dimension

        if keys_by_dimension.keys() != set(self.partition_dimension_names):
            raise DagsterUnknownPartitionError(
                f"Invalid partition key {partition_key}. The dimensions of the partition key are"
                " not the dimensions of the partitions definition."
            )

        partitions_by_dimension: Dict[str, Partition] = {}
        for dimension in self.partitions_defs:
            partition = dimension.partitions_def.get_partition(
                keys_by_dimension[dimension.name],
                current_time=current_time,
                dynamic_partitions_store=dynamic_partitions_store,
            )
            if not partition:
                check.failed(
                    f"Invalid partition key {partition_key}. The partition key does not exist in"
                    " the partitions definition."
                )
            partitions_by_dimension[dimension.name] = partition

        return Partition(value=partitions_by_dimension, name=partition_key)

    def get_partitions(
        self,
        current_time: Optional[datetime] = None,
        dynamic_partitions_store: Optional[DynamicPartitionsStore] = None,
    ) -> Sequence[Partition]:
        """Return one Partition per element of the cross-product of the dimensions."""
        partition_sequences = [
            partition_dim.partitions_def.get_partitions(
                current_time=current_time, dynamic_partitions_store=dynamic_partitions_store
            )
            for partition_dim in self._partitions_defs
        ]

        def get_multi_dimensional_partition(partitions_tuple: Tuple[Partition, ...]) -> Partition:
            check.invariant(len(partitions_tuple) == len(self._partitions_defs))

            # itertools.product yields tuples in the same order as partition_sequences,
            # which was built from self._partitions_defs, so the two line up positionally.
            partitions_by_dimension: Dict[str, Partition] = {
                dim_def.name: partition
                for dim_def, partition in zip(self._partitions_defs, partitions_tuple)
            }

            return Partition(
                value=partitions_by_dimension,
                name=MultiPartitionKey(
                    {
                        dimension_key: partition.name
                        for dimension_key, partition in partitions_by_dimension.items()
                    }
                ),
            )

        return [
            get_multi_dimensional_partition(partitions_tuple)
            for partitions_tuple in itertools.product(*partition_sequences)
        ]

    def __eq__(self, other):
        return (
            isinstance(other, MultiPartitionsDefinition)
            and self.partitions_defs == other.partitions_defs
        )

    def __hash__(self):
        return hash(
            tuple(
                [
                    (partitions_def.name, partitions_def.__repr__())
                    for partitions_def in self.partitions_defs
                ]
            )
        )

    def __str__(self) -> str:
        dimension_1 = self._partitions_defs[0]
        dimension_2 = self._partitions_defs[1]
        partition_str = (
            "Multi-partitioned, with dimensions: \n"
            f"{dimension_1.name.capitalize()}: {str(dimension_1.partitions_def)} \n"
            f"{dimension_2.name.capitalize()}: {str(dimension_2.partitions_def)}"
        )
        return partition_str

    def __repr__(self) -> str:
        # The closing parenthesis balances the one opened after the class name.
        return f"{type(self).__name__}(dimensions={[str(dim) for dim in self.partitions_defs]})"

    def get_partition_key_from_str(self, partition_key_str: str) -> MultiPartitionKey:
        """Given a string representation of a partition key, returns a MultiPartitionKey object."""
        check.str_param(partition_key_str, "partition_key_str")

        partition_key_strs = partition_key_str.split(MULTIPARTITION_KEY_DELIMITER)
        check.invariant(
            len(partition_key_strs) == len(self.partitions_defs),
            (
                f"Expected {len(self.partitions_defs)} partition keys in partition key string"
                f" {partition_key_str}, but got {len(partition_key_strs)}"
            ),
        )

        # The i-th segment of the string belongs to the i-th name-sorted dimension.
        return MultiPartitionKey(
            {dim.name: partition_key_strs[i] for i, dim in enumerate(self._partitions_defs)}
        )

    def _get_primary_and_secondary_dimension(
        self,
    ) -> Tuple[PartitionDimensionDefinition, PartitionDimensionDefinition]:
        # Multipartitions subsets are serialized by primary dimension. If changing
        # the selection of primary/secondary dimension, will need to also update the
        # serialization of MultiPartitionsSubsets

        time_dimensions = [
            dim
            for dim in self.partitions_defs
            if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
        ]
        if len(time_dimensions) == 1:
            # Exactly one time window dimension: it is primary, the other is secondary.
            primary_dimension, secondary_dimension = time_dimensions[0], next(
                iter([dim for dim in self.partitions_defs if dim != time_dimensions[0]])
            )
        else:
            # Zero or two time window dimensions: fall back to name-sorted order.
            primary_dimension, secondary_dimension = (
                self.partitions_defs[0],
                self.partitions_defs[1],
            )

        return primary_dimension, secondary_dimension

    @property
    def primary_dimension(self) -> PartitionDimensionDefinition:
        """The sole time window dimension if there is exactly one, else the first dimension."""
        return self._get_primary_and_secondary_dimension()[0]

    @property
    def secondary_dimension(self) -> PartitionDimensionDefinition:
        """The dimension that is not the primary dimension."""
        return self._get_primary_and_secondary_dimension()[1]

    def get_tags_for_partition_key(self, partition_key: str) -> Mapping[str, str]:
        """Return the superclass tags for the key plus one tag per dimension."""
        partition_key = cast(MultiPartitionKey, self.get_partition_key_from_str(partition_key))
        tags = {**super().get_tags_for_partition_key(partition_key)}
        tags.update(get_tags_from_multi_partition_key(partition_key))
        return tags

    @property
    def time_window_dimension(self) -> PartitionDimensionDefinition:
        """The single time window dimension; fails the invariant check if there is not exactly one."""
        time_window_dims = [
            dim
            for dim in self.partitions_defs
            if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
        ]
        check.invariant(
            len(time_window_dims) == 1, "Expected exactly one time window partitioned dimension"
        )
        return next(iter(time_window_dims))

    def get_cron_schedule(
        self,
        minute_of_hour: Optional[int] = None,
        hour_of_day: Optional[int] = None,
        day_of_week: Optional[int] = None,
        day_of_month: Optional[int] = None,
    ) -> str:
        """Delegate to ``get_cron_schedule`` of the time window dimension's definition."""
        return cast(
            TimeWindowPartitionsDefinition, self.time_window_dimension.partitions_def
        ).get_cron_schedule(minute_of_hour, hour_of_day, day_of_week, day_of_month)

    @property
    def timezone(self) -> Optional[str]:
        """The ``timezone`` of the time window dimension's definition."""
        return cast(
            TimeWindowPartitionsDefinition, self.time_window_dimension.partitions_def
        ).timezone

    def get_multipartition_keys_with_dimension_value(
        self,
        dimension_name: str,
        dimension_partition_key: str,
        dynamic_partitions_store: Optional[DynamicPartitionsStore] = None,
        current_time: Optional[datetime] = None,
    ) -> Sequence[MultiPartitionKey]:
        """Return every multi-partition key whose ``dimension_name`` key is ``dimension_partition_key``.

        The given dimension is pinned to the given key and combined with all partition
        keys of the remaining dimension(s).
        """
        check.str_param(dimension_name, "dimension_name")
        check.str_param(dimension_partition_key, "dimension_partition_key")

        matching_dimensions = [
            dimension for dimension in self.partitions_defs if dimension.name == dimension_name
        ]
        other_dimensions = [
            dimension for dimension in self.partitions_defs if dimension.name != dimension_name
        ]

        check.invariant(
            len(matching_dimensions) == 1,
            (
                f"Dimension {dimension_name} not found in MultiPartitionsDefinition with dimensions"
                f" {[dim.name for dim in self.partitions_defs]}"
            ),
        )

        partition_sequences = [
            partition_dim.partitions_def.get_partition_keys(
                current_time=current_time, dynamic_partitions_store=dynamic_partitions_store
            )
            for partition_dim in other_dimensions
        ] + [[dimension_partition_key]]

        # Names of partitions dimensions in the same order as partition_sequences
        partition_dim_names = [dim.name for dim in other_dimensions] + [dimension_name]

        return [
            MultiPartitionKey(
                {
                    partition_dim_names[i]: partition_key
                    for i, partition_key in enumerate(partitions_tuple)
                }
            )
            for partitions_tuple in itertools.product(*partition_sequences)
        ]

    def get_num_partitions(
        self,
        current_time: Optional[datetime] = None,
        dynamic_partitions_store: Optional[DynamicPartitionsStore] = None,
    ) -> int:
        """Return the product of the partition counts of all dimensions."""
        # Static partitions definitions can contain duplicate keys (will throw error in 1.3.0)
        # In the meantime, relying on get_num_partitions to handle duplicates to display
        # correct counts in Dagit
        dimension_counts = [
            dim.partitions_def.get_num_partitions(
                current_time=current_time, dynamic_partitions_store=dynamic_partitions_store
            )
            for dim in self.partitions_defs
        ]
        return reduce(lambda x, y: x * y, dimension_counts, 1)
class MultiPartitionsSubset(DefaultPartitionsSubset):
    """A DefaultPartitionsSubset whose keys are parsed into multi-partition keys."""

    def __init__(
        self,
        partitions_def: MultiPartitionsDefinition,
        subset: Optional[Set[str]] = None,
    ):
        check.inst_param(partitions_def, "partitions_def", MultiPartitionsDefinition)

        parsed_keys: Set[str] = set()
        if subset:
            for key in subset:
                # Keys that lack the delimiter are left out instead of being parsed.
                if MULTIPARTITION_KEY_DELIMITER in key:
                    parsed_keys.add(partitions_def.get_partition_key_from_str(key))

        super().__init__(partitions_def, parsed_keys)

    def with_partition_keys(self, partition_keys: Iterable[str]) -> "MultiPartitionsSubset":
        """Return a new subset holding the current keys plus ``partition_keys``."""
        combined_keys = self._subset | set(partition_keys)
        multi_partitions_def = cast(MultiPartitionsDefinition, self._partitions_def)
        return MultiPartitionsSubset(multi_partitions_def, combined_keys)


def get_tags_from_multi_partition_key(multi_partition_key: MultiPartitionKey) -> Mapping[str, str]:
    """Build one tag per dimension, mapping the dimension's tag name to its partition key."""
    check.inst_param(multi_partition_key, "multi_partition_key", MultiPartitionKey)

    tags: Dict[str, str] = {}
    for dimension in multi_partition_key.dimension_keys:
        tag_name = get_multidimensional_partition_tag(dimension.dimension_name)
        tags[tag_name] = dimension.partition_key
    return tags


def get_multipartition_key_from_tags(tags: Mapping[str, str]) -> str:
    """Build a MultiPartitionKey from the tags that carry the multi-dimensional prefix.

    The dimension name is whatever follows MULTIDIMENSIONAL_PARTITION_PREFIX in the tag
    name; tags without that prefix are ignored.
    """
    prefix_length = len(MULTIDIMENSIONAL_PARTITION_PREFIX)
    partitions_by_dimension: Dict[str, str] = {
        tag[prefix_length:]: value
        for tag, value in tags.items()
        if tag.startswith(MULTIDIMENSIONAL_PARTITION_PREFIX)
    }
    return MultiPartitionKey(partitions_by_dimension)