Make Datasets Pathlike (#36947)
This makes Dataset inherit from os.PathLike so that datasets can be used directly by
the Object Storage API.
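
For context, a minimal sketch of what the change enables (the bucket and key below are hypothetical; behaviour follows the tests added in this commit):

import os

from airflow.datasets import Dataset
from airflow.io.path import ObjectStoragePath

# Dataset now implements os.PathLike: os.fspath() returns its URI.
dataset = Dataset(uri="s3://my-bucket/raw/events")
assert os.fspath(dataset) == "s3://my-bucket/raw/events"

# Because of that, the Object Storage API can consume a Dataset directly
# instead of requiring a plain string.
path = ObjectStoragePath(dataset)
assert path.protocol == "s3"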
bolkedebruin authored Jan 23, 2024
1 parent 0381f7f commit f3b7cfc
Showing 3 changed files with 22 additions and 1 deletion.
6 changes: 5 additions & 1 deletion airflow/datasets/__init__.py
@@ -16,14 +16,15 @@
 # under the License.
 from __future__ import annotations
 
+import os
 from typing import Any, ClassVar
 from urllib.parse import urlsplit
 
 import attr
 
 
 @attr.define()
-class Dataset:
+class Dataset(os.PathLike):
     """A Dataset is used for marking data dependencies between workflows."""
 
     uri: str = attr.field(validator=[attr.validators.min_len(1), attr.validators.max_len(3000)])
@@ -42,3 +43,6 @@ def _check_uri(self, attr, uri: str):
         parsed = urlsplit(uri)
         if parsed.scheme and parsed.scheme.lower() == "airflow":
             raise ValueError(f"{attr.name!r} scheme `airflow` is reserved")
+
+    def __fspath__(self):
+        return self.uri
8 changes: 8 additions & 0 deletions tests/datasets/test_dataset.py
@@ -17,6 +17,8 @@
 
 from __future__ import annotations
 
+import os
+
 import pytest
 
 from airflow.datasets import Dataset
@@ -46,3 +48,9 @@ def test_uri_with_scheme():
 def test_uri_without_scheme():
     dataset = Dataset(uri="example_dataset")
     EmptyOperator(task_id="task1", outlets=[dataset])
+
+
+def test_fspath():
+    uri = "s3://example_dataset"
+    dataset = Dataset(uri=uri)
+    assert os.fspath(dataset) == uri
9 changes: 9 additions & 0 deletions tests/io/test_path.py
@@ -26,6 +26,7 @@
 from fsspec.implementations.local import LocalFileSystem
 from fsspec.utils import stringify_path
 
+from airflow.datasets import Dataset
 from airflow.io import _register_filesystems, get_fs
 from airflow.io.path import ObjectStoragePath
 from airflow.io.store import _STORE_CACHE, ObjectStore, attach
@@ -309,3 +310,11 @@ def test_backwards_compat(self):
         finally:
             # Reset the cache to avoid side effects
             _register_filesystems.cache_clear()
+
+    def test_dataset(self):
+        p = "s3"
+        f = "/tmp/foo"
+        i = Dataset(uri=f"{p}://{f}", extra={"foo": "bar"})
+        o = ObjectStoragePath(i)
+        assert o.protocol == p
+        assert o.path == f
