Skip to content

Commit df8a1e9

Browse files
authored
Simplify construction of Ref and OutRef directly (#3730)
1 parent 21708e9 commit df8a1e9

5 files changed

Lines changed: 148 additions & 8 deletions

File tree

crates/libs/core/src/out_ref.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,9 @@ impl<'a, T: Type<T>> From<&'a mut T::Default> for OutRef<'a, T> {
2929
unsafe { core::mem::transmute(from) }
3030
}
3131
}
32+
33+
impl<T: Type<T>> Default for OutRef<'_, T> {
34+
fn default() -> Self {
35+
unsafe { core::mem::zeroed() }
36+
}
37+
}

crates/libs/core/src/ref.rs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,43 @@ impl<T: Type<T>> Ref<'_, T> {
4646
}
4747
}
4848

49+
impl<T: Type<T>> Default for Ref<'_, T> {
50+
fn default() -> Self {
51+
unsafe { core::mem::zeroed() }
52+
}
53+
}
54+
4955
impl<T: Type<T>> core::ops::Deref for Ref<'_, T> {
5056
type Target = T::Default;
5157
fn deref(&self) -> &Self::Target {
5258
unsafe { transmute(&self.0) }
5359
}
5460
}
5561

56-
impl<'a, T: Type<T>> From<&'a T::Default> for Ref<'a, T> {
57-
fn from(from: &'a T::Default) -> Self {
62+
impl<'a, T: Type<T, InterfaceType>> From<&'a Option<T>> for Ref<'a, T>
63+
where
64+
T: TypeKind<TypeKind = InterfaceType>,
65+
{
66+
fn from(from: &'a Option<T>) -> Self {
67+
unsafe { core::mem::transmute_copy(from) }
68+
}
69+
}
70+
71+
impl<'a, T: Type<T, InterfaceType>> From<Option<&'a T>> for Ref<'a, T>
72+
where
73+
T: TypeKind<TypeKind = InterfaceType>,
74+
{
75+
fn from(from: Option<&'a T>) -> Self {
76+
if let Some(from) = from {
77+
unsafe { core::mem::transmute_copy(from) }
78+
} else {
79+
unsafe { core::mem::zeroed() }
80+
}
81+
}
82+
}
83+
84+
impl<'a, T: Type<T>> From<&'a T> for Ref<'a, T> {
85+
fn from(from: &'a T) -> Self {
5886
unsafe { core::mem::transmute_copy(from) }
5987
}
6088
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#![allow(non_camel_case_types)]
2+
3+
use windows::Win32::Foundation::E_POINTER;
4+
use windows_core::*;
5+
6+
#[interface("6120f260-090f-4b55-a09f-3cce3ffb35a4")]
7+
unsafe trait ITest: IUnknown {
8+
fn get(&self) -> usize;
9+
}
10+
11+
#[implement(ITest)]
12+
struct Test(usize);
13+
14+
impl Drop for Test {
15+
fn drop(&mut self) {
16+
println!("Drop for Test({})", self.0);
17+
}
18+
}
19+
20+
impl ITest_Impl for Test_Impl {
21+
unsafe fn get(&self) -> usize {
22+
self.0
23+
}
24+
}
25+
26+
fn call_hstring(input: Ref<HSTRING>) -> HSTRING {
27+
input.clone()
28+
}
29+
30+
fn call_object(input: Ref<ITest>) -> Option<ITest> {
31+
input.clone()
32+
}
33+
34+
#[test]
35+
fn test_ref() {
36+
assert_eq!(call_hstring(h!("test").into()), "test");
37+
38+
// Drop lifetime test
39+
{
40+
let test: ITest = Test(12).into();
41+
let clone = call_object((&test).into());
42+
drop(test);
43+
assert_eq!(unsafe { clone.unwrap().get() }, 12);
44+
}
45+
46+
// Inline test (`get` is the only unsafe call)
47+
unsafe {
48+
assert_eq!(
49+
call_object((&ITest::from(Test(23))).into()).unwrap().get(),
50+
23
51+
);
52+
}
53+
54+
// From &T
55+
{
56+
let test: ITest = Test(34).into();
57+
let test = call_object((&test).into());
58+
assert_eq!(unsafe { test.unwrap().get() }, 34);
59+
}
60+
61+
// From &Option<T>
62+
{
63+
let test: Option<ITest> = Some(Test(45).into());
64+
let test = call_object((&test).into());
65+
assert_eq!(unsafe { test.unwrap().get() }, 45);
66+
}
67+
68+
// From Option<&T>
69+
{
70+
let test: ITest = Test(56).into();
71+
let test = call_object(Some(&test).into());
72+
assert_eq!(unsafe { test.unwrap().get() }, 56);
73+
}
74+
}
75+
76+
fn return_hstring(input: Ref<HSTRING>, output: OutRef<HSTRING>) -> Result<()> {
77+
output.write(input.clone())
78+
}
79+
80+
fn return_object(input: Ref<ITest>, output: OutRef<ITest>) -> Result<()> {
81+
output.write(input.clone())
82+
}
83+
84+
#[test]
85+
fn test_out_ref() {
86+
let mut result = HSTRING::new();
87+
return_hstring(h!("test").into(), (&mut result).into()).unwrap();
88+
assert_eq!(result, "test");
89+
90+
// input and output
91+
let input: ITest = Test(11).into();
92+
let mut output = None;
93+
return_object((&input).into(), (&mut output).into()).unwrap();
94+
drop(input);
95+
assert!(output.is_some());
96+
assert_eq!(unsafe { output.unwrap().get() }, 11);
97+
98+
// "None" input and output
99+
let mut output = None;
100+
return_object((&None).into(), (&mut output).into()).unwrap();
101+
assert!(output.is_none());
102+
103+
// default (optional) input and output
104+
let error = return_object(Ref::default(), OutRef::default()).unwrap_err();
105+
assert_eq!(error.code(), E_POINTER);
106+
}

crates/tests/winrt/noexcept/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@ pub fn consume(test: &ITest) -> Result<()> {
99
fn consume(test: Ref<ITest>) -> HRESULT;
1010
}
1111

12-
unsafe { consume(std::mem::transmute_copy(test)).ok() }
12+
unsafe { consume(test.into()).ok() }
1313
}
1414

1515
pub fn produce() -> Result<ITest> {
1616
unsafe extern "system" {
17-
fn produce(test: *mut *mut std::ffi::c_void) -> HRESULT;
17+
fn produce(test: OutRef<ITest>) -> HRESULT;
1818
}
1919

2020
unsafe {
2121
let mut test = None;
22-
produce(&mut test as *mut _ as *mut _).ok()?;
22+
produce((&mut test).into()).ok()?;
2323
Type::from_default(&test)
2424
}
2525
}

crates/tests/winrt/ref_params/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@ pub fn consume(test: &ITest) -> Result<()> {
99
fn consume(test: Ref<ITest>) -> HRESULT;
1010
}
1111

12-
unsafe { consume(std::mem::transmute_copy(test)).ok() }
12+
unsafe { consume(test.into()).ok() }
1313
}
1414

1515
pub fn produce() -> Result<ITest> {
1616
unsafe extern "system" {
17-
fn produce(test: *mut *mut std::ffi::c_void) -> HRESULT;
17+
fn produce(test: OutRef<ITest>) -> HRESULT;
1818
}
1919

2020
unsafe {
2121
let mut test = None;
22-
produce(&mut test as *mut _ as *mut _).ok()?;
22+
produce((&mut test).into()).ok()?;
2323
Type::from_default(&test)
2424
}
2525
}

0 commit comments

Comments
 (0)