Skip to content

Commit

Permalink
PyO3: introduce getattr_bound helper
Browse files Browse the repository at this point in the history
  • Loading branch information
tdyas committed Oct 2, 2024
1 parent 3d29076 commit b9919b6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 11 deletions.
24 changes: 19 additions & 5 deletions src/rust/engine/src/externs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ pub fn store_bool(py: Python, val: bool) -> Value {
///
/// Gets an attribute of the given value as the given type.
///
pub fn getattr<'py, T>(value: &'py PyAny, field: &str) -> Result<T, String>
pub fn getattr_bound<'py, T>(value: &Bound<'py, PyAny>, field: &str) -> Result<T, String>
where
T: FromPyObject<'py>,
{
Expand All @@ -185,6 +185,13 @@ where
})
}

pub fn getattr<'py, T>(value: &'py PyAny, field: &str) -> Result<T, String>
where
T: FromPyObject<'py>,
{
getattr_bound(&value.as_borrowed(), field)
}

///
/// Collect the Values contained within an outer Python Iterable PyObject.
///
Expand Down Expand Up @@ -212,19 +219,26 @@ pub fn collect_iterable(value: &PyAny) -> Result<Vec<&PyAny>, String> {
}

/// Read a `FrozenDict[str, T]`.
pub fn getattr_from_str_frozendict<'p, T: FromPyObject<'p>>(
value: &'p PyAny,
pub fn getattr_from_str_frozendict_bound<'py, T: FromPyObject<'py>>(
value: &Bound<'py, PyAny>,
field: &str,
) -> BTreeMap<String, T> {
let frozendict = getattr(value, field).unwrap();
let pydict: &PyDict = getattr(frozendict, "_data").unwrap();
let frozendict: Bound<PyAny> = getattr_bound(value, field).unwrap();
let pydict: Bound<PyDict> = getattr_bound(&frozendict, "_data").unwrap();
pydict
.items()
.into_iter()
.map(|kv_pair| kv_pair.extract().unwrap())
.collect()
}

pub fn getattr_from_str_frozendict<'py, T: FromPyObject<'py>>(
value: &'py PyAny,
field: &str,
) -> BTreeMap<String, T> {
getattr_from_str_frozendict_bound(&value.as_borrowed(), field)
}

pub fn getattr_as_optional_string(value: &PyAny, field: &str) -> PyResult<Option<String>> {
// TODO: It's possible to view a python string as a `Cow<str>`, so we could avoid actually
// cloning in some cases.
Expand Down
12 changes: 6 additions & 6 deletions src/rust/engine/src/nodes/downloaded_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,17 @@ impl DownloadedFile {
let (url_str, expected_digest, auth_headers, retry_delay_duration, max_attempts) =
Python::with_gil(|py| {
let py_download_file_val = self.0.to_value();
let py_download_file = (*py_download_file_val).as_ref(py);
let url_str: String = externs::getattr(py_download_file, "url")
let py_download_file = (*py_download_file_val).bind(py);
let url_str: String = externs::getattr_bound(py_download_file, "url")
.map_err(|e| format!("Failed to get `url` for field: {e}"))?;
let auth_headers =
externs::getattr_from_str_frozendict(py_download_file, "auth_headers");
externs::getattr_from_str_frozendict_bound(py_download_file, "auth_headers");
let py_file_digest: PyFileDigest =
externs::getattr(py_download_file, "expected_digest")?;
externs::getattr_bound(py_download_file, "expected_digest")?;
let retry_delay_duration: Duration =
externs::getattr(py_download_file, "retry_error_duration")?;
externs::getattr_bound(py_download_file, "retry_error_duration")?;
let max_attempts: NonZeroUsize =
externs::getattr(py_download_file, "max_attempts")?;
externs::getattr_bound(py_download_file, "max_attempts")?;
Ok::<_, String>((
url_str,
py_file_digest.0,
Expand Down
5 changes: 5 additions & 0 deletions src/rust/engine/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ impl Value {
Err(arc_handle) => arc_handle.clone_ref(py),
}
}

/// Bind this value to the given Pythn context as a `pyo3::Bound` smart pointer.
pub fn bind<'py>(&self, py: Python<'py>) -> &Bound<'py, PyAny> {
self.0.bind(py)
}
}

impl workunit_store::Value for Value {
Expand Down

0 comments on commit b9919b6

Please sign in to comment.