Skip to main content

pyo3/conversions/std/
set.rs

1use std::{cmp, collections, hash};
2
3#[cfg(feature = "experimental-inspect")]
4use crate::inspect::{type_hint_subscript, PyStaticExpr};
5#[cfg(feature = "experimental-inspect")]
6use crate::type_object::PyTypeInfo;
7use crate::{
8    conversion::{FromPyObjectOwned, IntoPyObject},
9    types::{
10        any::PyAnyMethods, frozenset::PyFrozenSetMethods, set::PySetMethods, PyFrozenSet, PySet,
11    },
12    Borrowed, Bound, FromPyObject, PyAny, PyErr, Python,
13};
14
15impl<'py, K, S> IntoPyObject<'py> for collections::HashSet<K, S>
16where
17    K: IntoPyObject<'py> + Eq + hash::Hash,
18    S: hash::BuildHasher + Default,
19{
20    type Target = PySet;
21    type Output = Bound<'py, Self::Target>;
22    type Error = PyErr;
23
24    #[cfg(feature = "experimental-inspect")]
25    const OUTPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, K::OUTPUT_TYPE);
26
27    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
28        PySet::new(py, self)
29    }
30}
31
32impl<'a, 'py, K, H> IntoPyObject<'py> for &'a collections::HashSet<K, H>
33where
34    &'a K: IntoPyObject<'py> + Eq + hash::Hash,
35    H: hash::BuildHasher,
36{
37    type Target = PySet;
38    type Output = Bound<'py, Self::Target>;
39    type Error = PyErr;
40    #[cfg(feature = "experimental-inspect")]
41    const OUTPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, <&K>::OUTPUT_TYPE);
42
43    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
44        PySet::new(py, self)
45    }
46}
47
48impl<'py, K, S> FromPyObject<'_, 'py> for collections::HashSet<K, S>
49where
50    K: FromPyObjectOwned<'py> + cmp::Eq + hash::Hash,
51    S: hash::BuildHasher + Default,
52{
53    type Error = PyErr;
54
55    #[cfg(feature = "experimental-inspect")]
56    const INPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, K::INPUT_TYPE);
57
58    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
59        match ob.cast::<PySet>() {
60            Ok(set) => set
61                .iter()
62                .map(|any| any.extract().map_err(Into::into))
63                .collect(),
64            Err(err) => {
65                if let Ok(frozen_set) = ob.cast::<PyFrozenSet>() {
66                    frozen_set
67                        .iter()
68                        .map(|any| any.extract().map_err(Into::into))
69                        .collect()
70                } else {
71                    Err(PyErr::from(err))
72                }
73            }
74        }
75    }
76}
77
78impl<'py, K> IntoPyObject<'py> for collections::BTreeSet<K>
79where
80    K: IntoPyObject<'py> + cmp::Ord,
81{
82    type Target = PySet;
83    type Output = Bound<'py, Self::Target>;
84    type Error = PyErr;
85
86    #[cfg(feature = "experimental-inspect")]
87    const OUTPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, K::OUTPUT_TYPE);
88
89    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
90        PySet::new(py, self)
91    }
92}
93
94impl<'a, 'py, K> IntoPyObject<'py> for &'a collections::BTreeSet<K>
95where
96    &'a K: IntoPyObject<'py> + cmp::Ord,
97    K: 'a,
98{
99    type Target = PySet;
100    type Output = Bound<'py, Self::Target>;
101    type Error = PyErr;
102
103    #[cfg(feature = "experimental-inspect")]
104    const OUTPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, <&K>::OUTPUT_TYPE);
105
106    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
107        PySet::new(py, self)
108    }
109}
110
111impl<'py, K> FromPyObject<'_, 'py> for collections::BTreeSet<K>
112where
113    K: FromPyObjectOwned<'py> + cmp::Ord,
114{
115    type Error = PyErr;
116
117    #[cfg(feature = "experimental-inspect")]
118    const INPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, K::INPUT_TYPE);
119
120    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
121        match ob.cast::<PySet>() {
122            Ok(set) => set
123                .iter()
124                .map(|any| any.extract().map_err(Into::into))
125                .collect(),
126            Err(err) => {
127                if let Ok(frozen_set) = ob.cast::<PyFrozenSet>() {
128                    frozen_set
129                        .iter()
130                        .map(|any| any.extract().map_err(Into::into))
131                        .collect()
132                } else {
133                    Err(PyErr::from(err))
134                }
135            }
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use crate::types::{any::PyAnyMethods, PyFrozenSet, PyList, PySet};
    use crate::{IntoPyObject, Python};
    use std::collections::{BTreeSet, HashSet};

    /// Both `set` and `frozenset` extract into a `HashSet`.
    #[test]
    fn test_extract_hashset() {
        Python::attach(|py| {
            let expected: HashSet<usize> = [1, 2, 3, 4, 5].iter().copied().collect();

            let set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();
            let hash_set: HashSet<usize> = set.extract().unwrap();
            assert_eq!(hash_set, expected);

            let frozen_set = PyFrozenSet::new(py, [1, 2, 3, 4, 5]).unwrap();
            let hash_set: HashSet<usize> = frozen_set.extract().unwrap();
            assert_eq!(hash_set, expected);
        });
    }

    /// Both `set` and `frozenset` extract into a `BTreeSet`.
    #[test]
    fn test_extract_btreeset() {
        Python::attach(|py| {
            let expected: BTreeSet<usize> = [1, 2, 3, 4, 5].iter().copied().collect();

            let set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();
            let btree_set: BTreeSet<usize> = set.extract().unwrap();
            assert_eq!(btree_set, expected);

            let frozen_set = PyFrozenSet::new(py, [1, 2, 3, 4, 5]).unwrap();
            let btree_set: BTreeSet<usize> = frozen_set.extract().unwrap();
            assert_eq!(btree_set, expected);
        });
    }

    /// Non-set inputs and sets with unconvertible elements are rejected.
    #[test]
    fn test_extract_set_errors() {
        Python::attach(|py| {
            // A list is neither a `set` nor a `frozenset`.
            let list = PyList::new(py, [1, 2, 3]).unwrap();
            assert!(list.extract::<HashSet<usize>>().is_err());
            assert!(list.extract::<BTreeSet<usize>>().is_err());

            // The container is a set, but its elements cannot become `usize`.
            let strings = PySet::new(py, ["a", "b"]).unwrap();
            assert!(strings.extract::<HashSet<usize>>().is_err());
            assert!(strings.extract::<BTreeSet<usize>>().is_err());
        });
    }

    /// Round-trips through both the borrowed and owned `IntoPyObject` impls.
    #[test]
    fn test_set_into_pyobject() {
        Python::attach(|py| {
            let bt: BTreeSet<u64> = [1, 2, 3, 4, 5].iter().cloned().collect();
            let hs: HashSet<u64> = [1, 2, 3, 4, 5].iter().cloned().collect();

            let bto = (&bt).into_pyobject(py).unwrap();
            let hso = (&hs).into_pyobject(py).unwrap();

            assert_eq!(bt, bto.extract::<BTreeSet<u64>>().unwrap());
            assert_eq!(hs, hso.extract::<HashSet<u64>>().unwrap());

            let bto_owned = bt.clone().into_pyobject(py).unwrap();
            let hso_owned = hs.clone().into_pyobject(py).unwrap();

            assert_eq!(bt, bto_owned.extract::<BTreeSet<u64>>().unwrap());
            assert_eq!(hs, hso_owned.extract::<HashSet<u64>>().unwrap());
        });
    }
}