diff --git a/src/rust/engine/src/externs/mod.rs b/src/rust/engine/src/externs/mod.rs index 92d615854e5..e7bce197b20 100644 --- a/src/rust/engine/src/externs/mod.rs +++ b/src/rust/engine/src/externs/mod.rs @@ -167,7 +167,7 @@ pub fn store_bool(py: Python, val: bool) -> Value { /// /// Gets an attribute of the given value as the given type. /// -pub fn getattr<'py, T>(value: &'py PyAny, field: &str) -> Result +pub fn getattr_bound<'py, T>(value: &Bound<'py, PyAny>, field: &str) -> Result where T: FromPyObject<'py>, { @@ -185,6 +185,13 @@ where }) } +pub fn getattr<'py, T>(value: &'py PyAny, field: &str) -> Result +where + T: FromPyObject<'py>, +{ + getattr_bound(&value.as_borrowed(), field) +} + /// /// Collect the Values contained within an outer Python Iterable PyObject. /// @@ -212,12 +219,12 @@ pub fn collect_iterable(value: &PyAny) -> Result, String> { } /// Read a `FrozenDict[str, T]`. -pub fn getattr_from_str_frozendict<'p, T: FromPyObject<'p>>( - value: &'p PyAny, +pub fn getattr_from_str_frozendict_bound<'py, T: FromPyObject<'py>>( + value: &Bound<'py, PyAny>, field: &str, ) -> BTreeMap { - let frozendict = getattr(value, field).unwrap(); - let pydict: &PyDict = getattr(frozendict, "_data").unwrap(); + let frozendict: Bound = getattr_bound(value, field).unwrap(); + let pydict: Bound = getattr_bound(&frozendict, "_data").unwrap(); pydict .items() .into_iter() @@ -225,6 +232,13 @@ pub fn getattr_from_str_frozendict<'p, T: FromPyObject<'p>>( .collect() } +pub fn getattr_from_str_frozendict<'py, T: FromPyObject<'py>>( + value: &'py PyAny, + field: &str, +) -> BTreeMap { + getattr_from_str_frozendict_bound(&value.as_borrowed(), field) +} + pub fn getattr_as_optional_string(value: &PyAny, field: &str) -> PyResult> { // TODO: It's possible to view a python string as a `Cow`, so we could avoid actually // cloning in some cases. diff --git a/src/rust/engine/src/nodes/downloaded_file.rs b/src/rust/engine/src/nodes/downloaded_file.rs index 80a7e97771a..d0f0c499606 100644 --- a/src/rust/engine/src/nodes/downloaded_file.rs +++ b/src/rust/engine/src/nodes/downloaded_file.rs @@ -98,17 +98,17 @@ impl DownloadedFile { let (url_str, expected_digest, auth_headers, retry_delay_duration, max_attempts) = Python::with_gil(|py| { let py_download_file_val = self.0.to_value(); - let py_download_file = (*py_download_file_val).as_ref(py); - let url_str: String = externs::getattr(py_download_file, "url") + let py_download_file = (*py_download_file_val).bind(py); + let url_str: String = externs::getattr_bound(py_download_file, "url") .map_err(|e| format!("Failed to get `url` for field: {e}"))?; let auth_headers = - externs::getattr_from_str_frozendict(py_download_file, "auth_headers"); + externs::getattr_from_str_frozendict_bound(py_download_file, "auth_headers"); let py_file_digest: PyFileDigest = - externs::getattr(py_download_file, "expected_digest")?; + externs::getattr_bound(py_download_file, "expected_digest")?; let retry_delay_duration: Duration = - externs::getattr(py_download_file, "retry_error_duration")?; + externs::getattr_bound(py_download_file, "retry_error_duration")?; let max_attempts: NonZeroUsize = - externs::getattr(py_download_file, "max_attempts")?; + externs::getattr_bound(py_download_file, "max_attempts")?; Ok::<_, String>(( url_str, py_file_digest.0, diff --git a/src/rust/engine/src/python.rs b/src/rust/engine/src/python.rs index 2f01cd9245a..a6e58cfeb04 100644 --- a/src/rust/engine/src/python.rs +++ b/src/rust/engine/src/python.rs @@ -312,6 +312,11 @@ impl Value { Err(arc_handle) => arc_handle.clone_ref(py), } } + + /// Bind this value to the given Pythn context as a `pyo3::Bound` smart pointer. + pub fn bind<'py>(&self, py: Python<'py>) -> &Bound<'py, PyAny> { + self.0.bind(py) + } } impl workunit_store::Value for Value {