import hashlib
import os
import shutil
import sys
from collections import defaultdict
from contextlib import contextmanager
from typing import IO, TYPE_CHECKING, Generator, Iterator, Mapping, Optional, Sequence, Tuple
from typing_extensions import Final
from watchdog.events import PatternMatchingEventHandler
from watchdog.observers.polling import PollingObserver
from dagster import (
Field,
Float,
StringSource,
_check as check,
)
from dagster._config.config_schema import UserConfigSchema
from dagster._core.execution.compute_logs import mirror_stream_to_file
from dagster._core.storage.pipeline_run import DagsterRun
from dagster._serdes import ConfigurableClass, ConfigurableClassData
from dagster._seven import json
from dagster._utils import ensure_dir, ensure_file, touch_file
from .captured_log_manager import (
CapturedLogContext,
CapturedLogData,
CapturedLogManager,
CapturedLogMetadata,
CapturedLogSubscription,
)
from .compute_log_manager import (
MAX_BYTES_FILE_READ,
ComputeIOType,
ComputeLogFileData,
ComputeLogManager,
ComputeLogSubscription,
)
if TYPE_CHECKING:
from dagster._core.storage.cloud_storage_compute_log_manager import LogSubscription
DEFAULT_WATCHDOG_POLLING_TIMEOUT: Final = 2.5
IO_TYPE_EXTENSION: Final[Mapping[ComputeIOType, str]] = {
ComputeIOType.STDOUT: "out",
ComputeIOType.STDERR: "err",
}
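# Hedged illustration: given the path scheme implemented by
# LocalComputeLogManager.get_captured_local_path below, a log key like
# ["<run_id>", "compute_logs", "<step_key>"] resolves to
# "<base_dir>/<run_id>/compute_logs/<step_key>.out" (and ".err" for stderr).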
MAX_FILENAME_LENGTH: Final = 255
class LocalComputeLogManager(CapturedLogManager, ComputeLogManager, ConfigurableClass):
"""Stores copies of stdout & stderr for each compute step locally on disk."""
def __init__(
self,
base_dir: str,
polling_timeout: Optional[float] = None,
inst_data: Optional[ConfigurableClassData] = None,
):
self._base_dir = base_dir
self._polling_timeout = check.opt_float_param(
polling_timeout, "polling_timeout", DEFAULT_WATCHDOG_POLLING_TIMEOUT
)
self._subscription_manager = LocalComputeLogSubscriptionManager(self)
self._inst_data = check.opt_inst_param(inst_data, "inst_data", ConfigurableClassData)
@property
def inst_data(self) -> Optional[ConfigurableClassData]:
return self._inst_data
@property
def polling_timeout(self) -> float:
return self._polling_timeout
@classmethod
def config_type(cls) -> UserConfigSchema:
return {
"base_dir": StringSource,
"polling_timeout": Field(Float, is_required=False),
}
@classmethod
def from_config_value(
cls, inst_data: Optional[ConfigurableClassData], config_value
) -> "LocalComputeLogManager":
return LocalComputeLogManager(inst_data=inst_data, **config_value)
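    # A minimal configuration sketch (values hypothetical). In dagster.yaml this
    # manager would be selected under the `compute_logs` key, matching the
    # schema declared in config_type above:
    #
    #     compute_logs:
    #       module: dagster._core.storage.local_compute_log_manager
    #       class: LocalComputeLogManager
    #       config:
    #         base_dir: /tmp/dagster/compute_logs
    #         polling_timeout: 2.5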
@contextmanager
def capture_logs(self, log_key: Sequence[str]) -> Generator[CapturedLogContext, None, None]:
outpath = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT])
errpath = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR])
with mirror_stream_to_file(sys.stdout, outpath), mirror_stream_to_file(sys.stderr, errpath):
yield CapturedLogContext(log_key)
# leave artifact on filesystem so that we know the capture is completed
touch_file(self.complete_artifact_path(log_key))
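    # Hedged usage sketch (log key names hypothetical): stdout/stderr printed
    # inside the block are mirrored to the captured files, and the completion
    # artifact is touched on exit.
    #
    #     with manager.capture_logs(["<run_id>", "compute_logs", "my_step"]):
    #         print("this line is mirrored to my_step.out")
    #     assert manager.is_capture_complete(["<run_id>", "compute_logs", "my_step"])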
@contextmanager
def open_log_stream(
self, log_key: Sequence[str], io_type: ComputeIOType
) -> Iterator[Optional[IO]]:
path = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type])
ensure_file(path)
        with open(path, "a+", encoding="utf-8") as f:
yield f
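    # Hedged usage sketch: append to a capture file directly, e.g. from a
    # process whose streams cannot be mirrored.
    #
    #     with manager.open_log_stream(log_key, ComputeIOType.STDOUT) as f:
    #         if f is not None:
    #             f.write("external log line\n")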
def is_capture_complete(self, log_key: Sequence[str]) -> bool:
return os.path.exists(self.complete_artifact_path(log_key))
def get_log_data(
self, log_key: Sequence[str], cursor: Optional[str] = None, max_bytes: Optional[int] = None
) -> CapturedLogData:
stdout_cursor, stderr_cursor = self.parse_cursor(cursor)
stdout, stdout_offset = self._read_bytes(
log_key, ComputeIOType.STDOUT, offset=stdout_cursor, max_bytes=max_bytes
)
stderr, stderr_offset = self._read_bytes(
log_key, ComputeIOType.STDERR, offset=stderr_cursor, max_bytes=max_bytes
)
return CapturedLogData(
log_key=log_key,
stdout=stdout,
stderr=stderr,
cursor=self.build_cursor(stdout_offset, stderr_offset),
)
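    # Hedged pagination sketch: the returned cursor encodes "stdout:stderr" byte
    # offsets, so feeding it back in returns only bytes appended since the last
    # call.
    #
    #     chunk = manager.get_log_data(log_key, max_bytes=1024)
    #     more = manager.get_log_data(log_key, cursor=chunk.cursor, max_bytes=1024)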
def get_log_metadata(self, log_key: Sequence[str]) -> CapturedLogMetadata:
return CapturedLogMetadata(
stdout_location=self.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT]
),
stderr_location=self.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]
),
stdout_download_url=self.get_captured_log_download_url(log_key, ComputeIOType.STDOUT),
stderr_download_url=self.get_captured_log_download_url(log_key, ComputeIOType.STDERR),
)
def delete_logs(
self, log_key: Optional[Sequence[str]] = None, prefix: Optional[Sequence[str]] = None
):
if log_key:
paths = [
self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT]),
self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]),
self.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT], partial=True
),
self.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR], partial=True
),
self.get_captured_local_path(log_key, "complete"),
]
for path in paths:
if os.path.exists(path) and os.path.isfile(path):
os.remove(path)
elif prefix:
dir_to_delete = os.path.join(self._base_dir, *prefix)
if os.path.exists(dir_to_delete) and os.path.isdir(dir_to_delete):
# recursively delete all files in dir
shutil.rmtree(dir_to_delete)
else:
check.failed("Must pass in either `log_key` or `prefix` argument to delete_logs")
def _read_bytes(
self,
log_key: Sequence[str],
io_type: ComputeIOType,
offset: Optional[int] = 0,
max_bytes: Optional[int] = None,
):
path = self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type])
return self.read_path(path, offset or 0, max_bytes)
    def parse_cursor(self, cursor: Optional[str] = None) -> Tuple[int, int]:
        # Translates a string cursor into a pair of byte offsets for stdout, stderr
        if not cursor:
            return 0, 0
        parts = cursor.split(":")
        if len(parts) != 2:
            return 0, 0
        stdout_offset, stderr_offset = [int(part) for part in parts]
        return stdout_offset, stderr_offset
def build_cursor(self, stdout_offset: int, stderr_offset: int) -> str:
return f"{stdout_offset}:{stderr_offset}"
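    # e.g. build_cursor(100, 250) -> "100:250"; parse_cursor("100:250") -> (100, 250)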
def complete_artifact_path(self, log_key):
return self.get_captured_local_path(log_key, "complete")
def read_path(
self,
path: str,
offset: int = 0,
max_bytes: Optional[int] = None,
):
if not os.path.exists(path) or not os.path.isfile(path):
return None, offset
with open(path, "rb") as f:
f.seek(offset, os.SEEK_SET)
if max_bytes is None:
data = f.read()
else:
data = f.read(max_bytes)
new_offset = f.tell()
return data, new_offset
def get_captured_log_download_url(self, log_key, io_type):
check.inst_param(io_type, "io_type", ComputeIOType)
url = "/logs"
for part in log_key:
url = f"{url}/{part}"
return f"{url}/{IO_TYPE_EXTENSION[io_type]}"
def get_captured_local_path(self, log_key: Sequence[str], extension: str, partial=False):
[*namespace, filebase] = log_key
filename = f"{filebase}.{extension}"
if partial:
filename = f"{filename}.partial"
        if len(filename) > MAX_FILENAME_LENGTH:
            # fall back to an md5 hash of the file base name to stay within
            # common filesystem filename-length limits
            filename = f"{hashlib.md5(filebase.encode('utf-8')).hexdigest()}.{extension}"
return os.path.join(self._base_dir, *namespace, filename)
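    # e.g. (hedged) get_captured_local_path(["abc123", "compute_logs", "my_step"], "out")
    # returns "<base_dir>/abc123/compute_logs/my_step.out"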
def subscribe(
self, log_key: Sequence[str], cursor: Optional[str] = None
) -> CapturedLogSubscription:
subscription = CapturedLogSubscription(self, log_key, cursor)
self.on_subscribe(subscription)
return subscription
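    # Hedged subscription sketch: fetch() is invoked on the subscription as the
    # watchdog observer (see LocalComputeLogSubscriptionManager) sees the
    # captured files change on disk.
    #
    #     subscription = manager.subscribe(log_key)
    #     ...  # consume updates pushed via subscription.fetch()
    #     manager.unsubscribe(subscription)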
def unsubscribe(self, subscription):
self.on_unsubscribe(subscription)
###############################################
#
# Methods for the ComputeLogManager interface
#
###############################################
@contextmanager
def _watch_logs(
self, pipeline_run: DagsterRun, step_key: Optional[str] = None
) -> Iterator[None]:
check.inst_param(pipeline_run, "pipeline_run", DagsterRun)
check.opt_str_param(step_key, "step_key")
log_key = self.build_log_key_for_run(
pipeline_run.run_id, step_key or pipeline_run.pipeline_name
)
with self.capture_logs(log_key):
yield
def get_local_path(self, run_id: str, key: str, io_type: ComputeIOType) -> str:
"""Legacy adapter from compute log manager to more generic captured log manager API."""
check.inst_param(io_type, "io_type", ComputeIOType)
log_key = self.build_log_key_for_run(run_id, key)
return self.get_captured_local_path(log_key, IO_TYPE_EXTENSION[io_type])
def read_logs_file(
self,
run_id: str,
key: str,
io_type: ComputeIOType,
cursor: int = 0,
max_bytes: int = MAX_BYTES_FILE_READ,
) -> ComputeLogFileData:
path = self.get_local_path(run_id, key, io_type)
if not os.path.exists(path) or not os.path.isfile(path):
return ComputeLogFileData(path=path, data=None, cursor=0, size=0, download_url=None)
# See: https://docs.python.org/2/library/stdtypes.html#file.tell for Windows behavior
with open(path, "rb") as f:
f.seek(cursor, os.SEEK_SET)
data = f.read(max_bytes)
cursor = f.tell()
stats = os.fstat(f.fileno())
# local download path
download_url = self.download_url(run_id, key, io_type)
return ComputeLogFileData(
path=path,
data=data.decode("utf-8"),
cursor=cursor,
size=stats.st_size,
download_url=download_url,
)
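    # Hedged legacy-API sketch: read the first 4 KiB of a step's stdout.
    #
    #     chunk = manager.read_logs_file(run_id, "my_step", ComputeIOType.STDOUT, cursor=0, max_bytes=4096)
    #     print(chunk.data, chunk.cursor, chunk.size)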
def get_key(self, pipeline_run: DagsterRun, step_key: Optional[str]):
check.inst_param(pipeline_run, "pipeline_run", DagsterRun)
check.opt_str_param(step_key, "step_key")
return step_key or pipeline_run.pipeline_name
def is_watch_completed(self, run_id: str, key: str) -> bool:
log_key = self.build_log_key_for_run(run_id, key)
return self.is_capture_complete(log_key)
def on_watch_start(self, pipeline_run: DagsterRun, step_key: Optional[str]):
pass
def on_watch_finish(self, pipeline_run: DagsterRun, step_key: Optional[str] = None):
check.inst_param(pipeline_run, "pipeline_run", DagsterRun)
check.opt_str_param(step_key, "step_key")
log_key = self.build_log_key_for_run(
pipeline_run.run_id, step_key or pipeline_run.pipeline_name
)
touchpath = self.complete_artifact_path(log_key)
touch_file(touchpath)
    def download_url(self, run_id: str, key: str, io_type: ComputeIOType):
        check.inst_param(io_type, "io_type", ComputeIOType)
        return f"/download/{run_id}/{key}/{io_type.value}"
def on_subscribe(self, subscription: "LogSubscription") -> None:
self._subscription_manager.add_subscription(subscription)
def on_unsubscribe(self, subscription: "LogSubscription") -> None:
self._subscription_manager.remove_subscription(subscription)
def dispose(self) -> None:
self._subscription_manager.dispose()
class LocalComputeLogSubscriptionManager:
def __init__(self, manager):
self._manager = manager
self._subscriptions = defaultdict(list)
self._watchers = {}
self._observer = None
def add_subscription(self, subscription: "LogSubscription") -> None:
check.inst_param(
subscription, "subscription", (ComputeLogSubscription, CapturedLogSubscription)
)
if self.is_complete(subscription):
subscription.fetch()
subscription.complete()
else:
log_key = self._log_key(subscription)
watch_key = self._watch_key(log_key)
self._subscriptions[watch_key].append(subscription)
self.watch(subscription)
def is_complete(self, subscription: "LogSubscription") -> bool:
check.inst_param(
subscription, "subscription", (ComputeLogSubscription, CapturedLogSubscription)
)
if isinstance(subscription, ComputeLogSubscription):
return self._manager.is_watch_completed(subscription.run_id, subscription.key)
return self._manager.is_capture_complete(subscription.log_key)
def remove_subscription(self, subscription: "LogSubscription") -> None:
check.inst_param(
subscription, "subscription", (ComputeLogSubscription, CapturedLogSubscription)
)
log_key = self._log_key(subscription)
watch_key = self._watch_key(log_key)
if subscription in self._subscriptions[watch_key]:
self._subscriptions[watch_key].remove(subscription)
subscription.complete()
def _log_key(self, subscription: "LogSubscription") -> Sequence[str]:
check.inst_param(
subscription, "subscription", (ComputeLogSubscription, CapturedLogSubscription)
)
if isinstance(subscription, ComputeLogSubscription):
return self._manager.build_log_key_for_run(subscription.run_id, subscription.key)
return subscription.log_key
def _watch_key(self, log_key: Sequence[str]) -> str:
return json.dumps(log_key)
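    # e.g. _watch_key(["abc123", "compute_logs", "my_step"]) returns
    # '["abc123", "compute_logs", "my_step"]'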
def remove_all_subscriptions(self, log_key: Sequence[str]) -> None:
watch_key = self._watch_key(log_key)
for subscription in self._subscriptions.pop(watch_key, []):
subscription.complete()
def watch(self, subscription: "LogSubscription") -> None:
log_key = self._log_key(subscription)
watch_key = self._watch_key(log_key)
if watch_key in self._watchers:
return
update_paths = [
self._manager.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT]),
self._manager.get_captured_local_path(log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]),
self._manager.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDOUT], partial=True
),
self._manager.get_captured_local_path(
log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR], partial=True
),
]
complete_paths = [self._manager.complete_artifact_path(log_key)]
        directory = os.path.dirname(
            self._manager.get_captured_local_path(
                log_key, IO_TYPE_EXTENSION[ComputeIOType.STDERR]
            ),
        )
        if not self._observer:
            # lazily start a single polling observer, shared across all watched log keys
            self._observer = PollingObserver(timeout=self._manager.polling_timeout)
            self._observer.start()
        ensure_dir(directory)
self._watchers[watch_key] = self._observer.schedule(
LocalComputeLogFilesystemEventHandler(self, log_key, update_paths, complete_paths),
str(directory),
)
def notify_subscriptions(self, log_key: Sequence[str]) -> None:
watch_key = self._watch_key(log_key)
for subscription in self._subscriptions[watch_key]:
subscription.fetch()
def unwatch(self, log_key: Sequence[str], handler) -> None:
watch_key = self._watch_key(log_key)
if watch_key in self._watchers:
self._observer.remove_handler_for_watch(handler, self._watchers[watch_key]) # type: ignore
del self._watchers[watch_key]
    def dispose(self) -> None:
        if self._observer:
            self._observer.stop()
            # wait up to 15 seconds for the observer thread to shut down
            self._observer.join(15)
class LocalComputeLogFilesystemEventHandler(PatternMatchingEventHandler):
def __init__(self, manager, log_key, update_paths, complete_paths):
self.manager = manager
self.log_key = log_key
self.update_paths = update_paths
self.complete_paths = complete_paths
patterns = update_paths + complete_paths
        super().__init__(patterns=patterns)
def on_created(self, event):
if event.src_path in self.complete_paths:
self.manager.remove_all_subscriptions(self.log_key)
self.manager.unwatch(self.log_key, self)
def on_modified(self, event):
if event.src_path in self.update_paths:
self.manager.notify_subscriptions(self.log_key)