diff --git a/src/Tests/FunctionalTests/Async/Program.cs b/src/Tests/FunctionalTests/Async/Program.cs index e85942f68..fa43cc064 100644 --- a/src/Tests/FunctionalTests/Async/Program.cs +++ b/src/Tests/FunctionalTests/Async/Program.cs @@ -1,6 +1,7 @@ using System; using System.IO; using System.Runtime.InteropServices; +using System.Threading; using System.Threading.Tasks; using TestComponentCSharp; using Windows.Foundation; @@ -111,6 +112,7 @@ return 111; } + // Test WindowsRuntimeExternalArrayBuffer CCW (created via AsBuffer()) var arr = new byte[100]; var buffer = arr.AsBuffer(); ptr = WindowsRuntimeMarshal.ConvertToUnmanaged(buffer); @@ -125,6 +127,20 @@ return 113; } + // Test WindowsRuntimePinnedArrayBuffer CCW (created via WindowsRuntimeBuffer.Create()) + var pinnedBuffer = WindowsRuntimeBuffer.Create(100); + ptr = WindowsRuntimeMarshal.ConvertToUnmanaged(pinnedBuffer); + if (ptr is null) + { + return 128; + } + + if (Marshal.QueryInterface((nint)ptr, typeof(IBuffer).GUID, out ptr2) != 0 || + ptr2 == IntPtr.Zero) + { + return 129; + } + var asyncOperation = randomAccessStream.ReadAsync(buffer, 50, InputStreamOptions.Partial); ptr = WindowsRuntimeMarshal.ConvertToUnmanaged(asyncOperation); if (ptr is null) @@ -162,6 +178,154 @@ return 117; } +// Test stream adapter span/memory overrides using InMemoryRandomAccessStream +{ + var random = new Random(42); + byte[] data = new byte[256]; + random.NextBytes(data); + + using var adaptedStream = new InMemoryRandomAccessStream().AsStream(); + + // Test Write(ReadOnlySpan) and Read(Span) + adaptedStream.Write(new ReadOnlySpan(data)); + adaptedStream.Seek(0, SeekOrigin.Begin); + + Span spanRead = new byte[256]; + int spanBytesRead = adaptedStream.Read(spanRead); + + if (spanBytesRead != 256) + { + return 118; + } + + if (!data.SequenceEqual(spanRead)) + { + return 119; + } + + // Test WriteAsync(ReadOnlyMemory) and ReadAsync(Memory) + adaptedStream.Seek(0, SeekOrigin.Begin); + await adaptedStream.WriteAsync(new ReadOnlyMemory(data)); + adaptedStream.Seek(0, SeekOrigin.Begin); + + Memory memoryRead = new byte[256]; + int memoryBytesRead = await adaptedStream.ReadAsync(memoryRead); + + if (memoryBytesRead != 256) + { + return 120; + } + + if (!data.SequenceEqual(memoryRead.Span)) + { + return 121; + } + + // Test ReadByte/WriteByte (which delegate to span overrides) + adaptedStream.Seek(0, SeekOrigin.Begin); + adaptedStream.WriteByte(0xAB); + adaptedStream.WriteByte(0xCD); + adaptedStream.Seek(0, SeekOrigin.Begin); + + if (adaptedStream.ReadByte() != 0xAB) + { + return 122; + } + + if (adaptedStream.ReadByte() != 0xCD) + { + return 123; + } + + // Test empty span/memory operations + if (adaptedStream.Read(Span.Empty) != 0) + { + return 124; + } + + adaptedStream.Write(ReadOnlySpan.Empty); + + if (await adaptedStream.ReadAsync(Memory.Empty) != 0) + { + return 125; + } + + await adaptedStream.WriteAsync(ReadOnlyMemory.Empty); + + // Test cancellation for memory-based async operations + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + try + { + _ = await adaptedStream.ReadAsync(new byte[256].AsMemory(), cts.Token); + return 126; + } + catch (OperationCanceledException) + { + } + + try + { + await adaptedStream.WriteAsync(new byte[256].AsMemory(), cts.Token); + return 127; + } + catch (OperationCanceledException) + { + } +} + +// Test writing each managed buffer type to a native WinRT stream (exercises CCW interop) +{ + byte[] testData = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]; + + // Test WindowsRuntimeExternalArrayBuffer (from AsBuffer()) written to native stream + using var stream1 = new InMemoryRandomAccessStream(); + IBuffer externalArrayBuffer = testData.AsBuffer(); + await stream1.WriteAsync(externalArrayBuffer); + stream1.Seek(0); + + byte[] read1 = new byte[8]; + IBuffer readBuffer1 = read1.AsBuffer(); + await stream1.ReadAsync(readBuffer1, 8, InputStreamOptions.None); + if (!testData.SequenceEqual(read1)) + { + return 130; + } + + // Test WindowsRuntimePinnedArrayBuffer (from WindowsRuntimeBuffer.Create()) written to native stream + using var stream2 = new InMemoryRandomAccessStream(); + IBuffer pinnedArrayBuffer = WindowsRuntimeBuffer.Create(testData); + await stream2.WriteAsync(pinnedArrayBuffer); + stream2.Seek(0); + + byte[] read2 = new byte[8]; + IBuffer readBuffer2 = read2.AsBuffer(); + await stream2.ReadAsync(readBuffer2, 8, InputStreamOptions.None); + if (!testData.SequenceEqual(read2)) + { + return 131; + } + + // Test WindowsRuntimePinnedMemoryBuffer (created internally by the stream adapter when + // using span/memory-based Write, which pins the data and wraps it in a PinnedMemoryBuffer + // before passing it as an IBuffer CCW to the native WinRT stream's WriteAsync) + using var stream3 = new InMemoryRandomAccessStream(); + using var adaptedStream3 = stream3.AsStream(); + adaptedStream3.Write(new ReadOnlySpan(testData)); + adaptedStream3.Dispose(); + + stream3.Seek(0); + + byte[] read3 = new byte[8]; + IBuffer readBuffer3 = read3.AsBuffer(); + await stream3.ReadAsync(readBuffer3, 8, InputStreamOptions.None); + if (!testData.SequenceEqual(read3)) + { + return 132; + } +} + return 100; static async Task InvokeAddAsync(Class instance, int lhs, int rhs) diff --git a/src/Tests/UnitTest/TestComponentCSharp_Tests.cs b/src/Tests/UnitTest/TestComponentCSharp_Tests.cs index 13470b8e0..8edcc8415 100644 --- a/src/Tests/UnitTest/TestComponentCSharp_Tests.cs +++ b/src/Tests/UnitTest/TestComponentCSharp_Tests.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; using System.Collections; using System.Collections.Generic; using System.Collections.ObjectModel; @@ -842,6 +843,387 @@ public void TestBuffer() Assert.IsTrue(arr1[1] == arr2[1]); } + [TestMethod] + public void TestStreamReadSpan() + { + var random = new Random(42); + byte[] data = new byte[256]; + random.NextBytes(data); + + using var stream = new InMemoryRandomAccessStream().AsStream(); + stream.Write(data, 0, data.Length); + stream.Seek(0, SeekOrigin.Begin); + + Span read = new byte[256]; + int bytesRead = stream.Read(read); + + Assert.AreEqual(256, bytesRead); + CollectionAssert.AreEqual(data, read.ToArray()); + } + + [TestMethod] + public void TestStreamReadSpanPartial() + { + var random = new Random(42); + byte[] data = new byte[256]; + random.NextBytes(data); + + using var stream = new InMemoryRandomAccessStream().AsStream(); + stream.Write(data, 0, data.Length); + stream.Seek(0, SeekOrigin.Begin); + + Span read = new byte[64]; + int bytesRead = stream.Read(read); + + Assert.AreEqual(64, bytesRead); + CollectionAssert.AreEqual(data[..64], read.ToArray()); + } + + [TestMethod] + public void TestStreamReadSpanEmpty() + { + using var stream = new InMemoryRandomAccessStream().AsStream(); + int bytesRead = stream.Read(Span.Empty); + Assert.AreEqual(0, bytesRead); + } + + [TestMethod] + public void TestStreamWriteSpan() + { + var random = new Random(42); + byte[] data = new byte[256]; + random.NextBytes(data); + + using var stream = new InMemoryRandomAccessStream().AsStream(); + stream.Write(new ReadOnlySpan(data)); + stream.Seek(0, SeekOrigin.Begin); + + byte[] read = new byte[256]; + stream.Read(read, 0, read.Length); + CollectionAssert.AreEqual(data, read); + } + + [TestMethod] + public void TestStreamWriteSpanEmpty() + { + using var stream = new InMemoryRandomAccessStream().AsStream(); + stream.Write(ReadOnlySpan.Empty); + Assert.AreEqual(0L, stream.Length); + } + + [TestMethod] + public void TestStreamReadAsyncMemory() + { + async Task TestAsync() + { + var random = new Random(42); + byte[] data = new byte[256]; + random.NextBytes(data); + + using var stream = new InMemoryRandomAccessStream().AsStream(); + await stream.WriteAsync(data, 0, data.Length); + stream.Seek(0, SeekOrigin.Begin); + + Memory read = new byte[256]; + int bytesRead = await stream.ReadAsync(read); + + Assert.AreEqual(256, bytesRead); + CollectionAssert.AreEqual(data, read.ToArray()); + } + + Assert.IsTrue(TestAsync().Wait(5000)); + } + + [TestMethod] + public void TestStreamReadAsyncMemoryPartial() + { + async Task TestAsync() + { + var random = new Random(42); + byte[] data = new byte[256]; + random.NextBytes(data); + + using var stream = new InMemoryRandomAccessStream().AsStream(); + await stream.WriteAsync(data, 0, data.Length); + stream.Seek(0, SeekOrigin.Begin); + + Memory read = new byte[64]; + int bytesRead = await stream.ReadAsync(read); + + Assert.AreEqual(64, bytesRead); + CollectionAssert.AreEqual(data[..64], read.ToArray()); + } + + Assert.IsTrue(TestAsync().Wait(5000)); + } + + [TestMethod] + public void TestStreamReadAsyncMemoryEmpty() + { + async Task TestAsync() + { + using var stream = new InMemoryRandomAccessStream().AsStream(); + int bytesRead = await stream.ReadAsync(Memory.Empty); + Assert.AreEqual(0, bytesRead); + } + + Assert.IsTrue(TestAsync().Wait(5000)); + } + + [TestMethod] + public void TestStreamWriteAsyncMemory() + { + async Task TestAsync() + { + var random = new Random(42); + byte[] data = new byte[256]; + random.NextBytes(data); + + using var stream = new InMemoryRandomAccessStream().AsStream(); + await stream.WriteAsync(new ReadOnlyMemory(data)); + stream.Seek(0, SeekOrigin.Begin); + + byte[] read = new byte[256]; + await stream.ReadAsync(read, 0, read.Length); + CollectionAssert.AreEqual(data, read); + } + + Assert.IsTrue(TestAsync().Wait(5000)); + } + + [TestMethod] + public void TestStreamWriteAsyncMemoryEmpty() + { + async Task TestAsync() + { + using var stream = new InMemoryRandomAccessStream().AsStream(); + await stream.WriteAsync(ReadOnlyMemory.Empty); + Assert.AreEqual(0L, stream.Length); + } + + Assert.IsTrue(TestAsync().Wait(5000)); + } + + [TestMethod] + public void TestStreamReadAsyncMemoryWithCancellation() + { + using var stream = new InMemoryRandomAccessStream().AsStream(); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + Memory buffer = new byte[256]; + bool threwCancellation = false; + + try + { + stream.ReadAsync(buffer, cts.Token).AsTask().GetAwaiter().GetResult(); + } + catch (OperationCanceledException) + { + threwCancellation = true; + } + + Assert.IsTrue(threwCancellation); + } + + [TestMethod] + public void TestStreamWriteAsyncMemoryWithCancellation() + { + using var stream = new InMemoryRandomAccessStream().AsStream(); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + ReadOnlyMemory buffer = new byte[256]; + bool threwCancellation = false; + + try + { + stream.WriteAsync(buffer, cts.Token).AsTask().GetAwaiter().GetResult(); + } + catch (OperationCanceledException) + { + threwCancellation = true; + } + + Assert.IsTrue(threwCancellation); + } + + [TestMethod] + public void TestStreamReadByteAfterSpanWrite() + { + using var stream = new InMemoryRandomAccessStream().AsStream(); + stream.Write(new ReadOnlySpan([0xAB, 0xCD])); + stream.Seek(0, SeekOrigin.Begin); + + Assert.AreEqual(0xAB, stream.ReadByte()); + Assert.AreEqual(0xCD, stream.ReadByte()); + Assert.AreEqual(-1, stream.ReadByte()); + } + + [TestMethod] + public void TestStreamWriteByteAndReadSpan() + { + using var stream = new InMemoryRandomAccessStream().AsStream(); + stream.WriteByte(0xAB); + stream.WriteByte(0xCD); + stream.Seek(0, SeekOrigin.Begin); + + Span read = new byte[2]; + int bytesRead = stream.Read(read); + + Assert.AreEqual(2, bytesRead); + Assert.AreEqual((byte)0xAB, read[0]); + Assert.AreEqual((byte)0xCD, read[1]); + } + + [TestMethod] + public void TestStreamSpanAndMemoryRoundTrip() + { + async Task TestAsync() + { + var random = new Random(42); + byte[] data = new byte[1024]; + random.NextBytes(data); + + using var stream = new InMemoryRandomAccessStream().AsStream(); + + // Write via span (sync) + stream.Write(new ReadOnlySpan(data, 0, 512)); + + // Write via memory (async) + await stream.WriteAsync(new ReadOnlyMemory(data, 512, 512)); + + stream.Seek(0, SeekOrigin.Begin); + + // Read via memory (async) + Memory readFirst = new byte[512]; + int bytesRead1 = await stream.ReadAsync(readFirst); + Assert.AreEqual(512, bytesRead1); + + // Read via span (sync) + Span readSecond = new byte[512]; + int bytesRead2 = stream.Read(readSecond); + Assert.AreEqual(512, bytesRead2); + + // Verify round-trip + CollectionAssert.AreEqual(data[..512], readFirst.ToArray()); + CollectionAssert.AreEqual(data[512..], readSecond.ToArray()); + } + + Assert.IsTrue(TestAsync().Wait(5000)); + } + + /// + /// A backed by unmanaged memory, used to test the pinned memory + /// code path in and + /// (i.e. the slow + /// path when + /// returns ). + /// + unsafe class UnmanagedMemoryManager : System.Buffers.MemoryManager + { + private byte* _pointer; + private readonly int _length; + + public UnmanagedMemoryManager(int length) + { + _pointer = (byte*)NativeMemory.AllocZeroed((nuint)length); + _length = length; + } + + public override Span GetSpan() => new(_pointer, _length); + + public override MemoryHandle Pin(int elementIndex = 0) => new(_pointer + elementIndex); + + public override void Unpin() { } + + protected override void Dispose(bool disposing) + { + if (_pointer is not null) + { + NativeMemory.Free(_pointer); + _pointer = null; + } + } + } + + [TestMethod] + public void TestStreamReadAsyncUnmanagedMemory() + { + async Task TestAsync() + { + var random = new Random(42); + byte[] data = new byte[256]; + random.NextBytes(data); + + using var stream = new InMemoryRandomAccessStream().AsStream(); + await stream.WriteAsync(data, 0, data.Length); + stream.Seek(0, SeekOrigin.Begin); + + using var manager = new UnmanagedMemoryManager(256); + Memory read = manager.Memory; + int bytesRead = await stream.ReadAsync(read); + + Assert.AreEqual(256, bytesRead); + CollectionAssert.AreEqual(data, read.ToArray()); + } + + Assert.IsTrue(TestAsync().Wait(5000)); + } + + [TestMethod] + public void TestStreamWriteAsyncUnmanagedMemory() + { + async Task TestAsync() + { + var random = new Random(42); + byte[] data = new byte[256]; + random.NextBytes(data); + + using var manager = new UnmanagedMemoryManager(256); + data.AsSpan().CopyTo(manager.Memory.Span); + + using var stream = new InMemoryRandomAccessStream().AsStream(); + await stream.WriteAsync((ReadOnlyMemory)manager.Memory); + stream.Seek(0, SeekOrigin.Begin); + + byte[] read = new byte[256]; + await stream.ReadAsync(read, 0, read.Length); + CollectionAssert.AreEqual(data, read); + } + + Assert.IsTrue(TestAsync().Wait(5000)); + } + + [TestMethod] + public void TestStreamUnmanagedMemoryRoundTrip() + { + async Task TestAsync() + { + var random = new Random(42); + byte[] data = new byte[512]; + random.NextBytes(data); + + using var stream = new InMemoryRandomAccessStream().AsStream(); + + // Write via unmanaged memory (exercises pinned memory buffer path) + using var writeManager = new UnmanagedMemoryManager(512); + data.AsSpan().CopyTo(writeManager.Memory.Span); + await stream.WriteAsync((ReadOnlyMemory)writeManager.Memory); + + stream.Seek(0, SeekOrigin.Begin); + + // Read via unmanaged memory (exercises pinned memory buffer path) + using var readManager = new UnmanagedMemoryManager(512); + int bytesRead = await stream.ReadAsync(readManager.Memory); + + Assert.AreEqual(512, bytesRead); + CollectionAssert.AreEqual(data, readManager.Memory.ToArray()); + } + + Assert.IsTrue(TestAsync().Wait(5000)); + } + #endif async Task TestStorageFileAsync() @@ -888,6 +1270,34 @@ public void TestWriteBuffer() Assert.IsTrue(InvokeWriteBufferAsync().Wait(1000)); } + [TestMethod] + public void TestWriteBufferPinnedArrayBuffer() + { + async Task TestAsync() + { + var random = new Random(42); + byte[] data = new byte[256]; + random.NextBytes(data); + + // WindowsRuntimeBuffer.Create() creates a WindowsRuntimePinnedArrayBuffer, + // which exercises a separate CCW code path than AsBuffer() (which creates + // a WindowsRuntimeExternalArrayBuffer). This test verifies that the pinned + // array buffer CCW is correctly used when writing to a native WinRT stream. + using var stream = new InMemoryRandomAccessStream(); + IBuffer buffer = WindowsRuntimeBuffer.Create(data); + await stream.WriteAsync(buffer); + + stream.Seek(0); + + byte[] readData = new byte[256]; + IBuffer readBuffer = readData.AsBuffer(); + await stream.ReadAsync(readBuffer, 256, InputStreamOptions.None); + CollectionAssert.AreEqual(data, readData); + } + + Assert.IsTrue(TestAsync().Wait(5000)); + } + [TestMethod] public unsafe void TestUri() { @@ -1856,11 +2266,18 @@ public unsafe void TestCCWMarshaler() Marshal.ThrowExceptionForHR(Marshal.QueryInterface((IntPtr)ccw.GetThisPtrUnsafe(), in IID_IMarshal, out var marshalCCW)); Assert.AreNotEqual(IntPtr.Zero, marshalCCW); + // Test WindowsRuntimeExternalArrayBuffer CCW (created via AsBuffer()) var array = new byte[] { 0x01 }; var buff = array.AsBuffer(); using WindowsRuntimeObjectReferenceValue ccw2 = WindowsRuntimeInterfaceMarshaller.ConvertToUnmanaged(buff, typeof(IBuffer).GUID); Marshal.ThrowExceptionForHR(Marshal.QueryInterface((IntPtr)ccw2.GetThisPtrUnsafe(), in IID_IMarshal, out var marshalCCW2)); Assert.AreNotEqual(IntPtr.Zero, marshalCCW2); + + // Test WindowsRuntimePinnedArrayBuffer CCW (created via WindowsRuntimeBuffer.Create()) + var pinnedBuff = WindowsRuntimeBuffer.Create(new byte[] { 0x01 }); + using WindowsRuntimeObjectReferenceValue ccw3 = WindowsRuntimeInterfaceMarshaller.ConvertToUnmanaged(pinnedBuff, typeof(IBuffer).GUID); + Marshal.ThrowExceptionForHR(Marshal.QueryInterface((IntPtr)ccw3.GetThisPtrUnsafe(), in IID_IMarshal, out var marshalCCW3)); + Assert.AreNotEqual(IntPtr.Zero, marshalCCW3); } [TestMethod] diff --git a/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimeExternalArrayBuffer.cs b/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimeExternalArrayBuffer.cs index c1908daa7..3ea7bdf7d 100644 --- a/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimeExternalArrayBuffer.cs +++ b/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimeExternalArrayBuffer.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.ComponentModel; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using WindowsRuntime; @@ -78,11 +77,7 @@ static WindowsRuntimeExternalArrayBufferInterfaceEntriesImpl() /// /// A custom implementation for . /// -[Obsolete(WindowsRuntimeConstants.PrivateImplementationDetailObsoleteMessage, - DiagnosticId = WindowsRuntimeConstants.PrivateImplementationDetailObsoleteDiagnosticId, - UrlFormat = WindowsRuntimeConstants.CsWinRTDiagnosticsUrlFormat)] -[EditorBrowsable(EditorBrowsableState.Never)] -public sealed unsafe class WindowsRuntimeExternalArrayBufferComWrappersMarshallerAttribute : WindowsRuntimeComWrappersMarshallerAttribute +file sealed unsafe class WindowsRuntimeExternalArrayBufferComWrappersMarshallerAttribute : WindowsRuntimeComWrappersMarshallerAttribute { /// public override void* GetOrCreateComInterfaceForObject(object value) diff --git a/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimePinnedArrayBuffer.cs b/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimePinnedArrayBuffer.cs index 46dfaab91..2aea75ad9 100644 --- a/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimePinnedArrayBuffer.cs +++ b/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimePinnedArrayBuffer.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.ComponentModel; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using WindowsRuntime; @@ -78,11 +77,7 @@ static WindowsRuntimePinnedArrayBufferInterfaceEntriesImpl() /// /// A custom implementation for . /// -[Obsolete(WindowsRuntimeConstants.PrivateImplementationDetailObsoleteMessage, - DiagnosticId = WindowsRuntimeConstants.PrivateImplementationDetailObsoleteDiagnosticId, - UrlFormat = WindowsRuntimeConstants.CsWinRTDiagnosticsUrlFormat)] -[EditorBrowsable(EditorBrowsableState.Never)] -public sealed unsafe class WindowsRuntimePinnedArrayBufferComWrappersMarshallerAttribute : WindowsRuntimeComWrappersMarshallerAttribute +file sealed unsafe class WindowsRuntimePinnedArrayBufferComWrappersMarshallerAttribute : WindowsRuntimeComWrappersMarshallerAttribute { /// public override void* GetOrCreateComInterfaceForObject(object value) diff --git a/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimePinnedMemoryBuffer.cs b/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimePinnedMemoryBuffer.cs new file mode 100644 index 000000000..da1268cde --- /dev/null +++ b/src/WinRT.Runtime2/ABI/WindowsRuntime.InteropServices/Buffers/WindowsRuntimePinnedMemoryBuffer.cs @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using WindowsRuntime; +using WindowsRuntime.InteropServices; +using WindowsRuntime.InteropServices.Marshalling; +using static System.Runtime.InteropServices.ComWrappers; + +#pragma warning disable CS0723, IDE0008, IDE0046, IDE1006 + +[assembly: TypeMapAssociation( + source: typeof(WindowsRuntimePinnedMemoryBuffer), + proxy: typeof(ABI.WindowsRuntime.InteropServices.WindowsRuntimePinnedMemoryBuffer))] + +namespace ABI.WindowsRuntime.InteropServices; + +/// +/// ABI type for . +/// +[WindowsRuntimeClassName("Windows.Storage.Streams.IBuffer")] +[WindowsRuntimePinnedMemoryBufferComWrappersMarshaller] +file static class WindowsRuntimePinnedMemoryBuffer; + +/// +/// The set of values for . +/// +file struct WindowsRuntimePinnedMemoryBufferInterfaceEntries +{ + public ComInterfaceEntry IBuffer; + public ComInterfaceEntry IBufferByteAccess; + public ComInterfaceEntry IStringable; + public ComInterfaceEntry IWeakReferenceSource; + public ComInterfaceEntry IMarshal; + public ComInterfaceEntry IAgileObject; + public ComInterfaceEntry IInspectable; + public ComInterfaceEntry IUnknown; +} + +/// +/// The implementation of . +/// +file static class WindowsRuntimePinnedMemoryBufferInterfaceEntriesImpl +{ + /// + /// The value for . + /// + [FixedAddressValueType] + public static readonly WindowsRuntimePinnedMemoryBufferInterfaceEntries Entries; + + /// + /// Initializes . + /// + static WindowsRuntimePinnedMemoryBufferInterfaceEntriesImpl() + { + Entries.IBuffer.IID = WellKnownWindowsInterfaceIIDs.IID_IBuffer; + Entries.IBuffer.Vtable = Windows.Storage.Streams.IBufferImpl.Vtable; + Entries.IBufferByteAccess.IID = WellKnownWindowsInterfaceIIDs.IID_IBufferByteAccess; + Entries.IBufferByteAccess.Vtable = WindowsRuntimePinnedMemoryBufferByteAccessImpl.Vtable; + Entries.IStringable.IID = WellKnownWindowsInterfaceIIDs.IID_IStringable; + Entries.IStringable.Vtable = IStringableImpl.Vtable; + Entries.IWeakReferenceSource.IID = WellKnownWindowsInterfaceIIDs.IID_IWeakReferenceSource; + Entries.IWeakReferenceSource.Vtable = IWeakReferenceSourceImpl.Vtable; + Entries.IMarshal.IID = WellKnownWindowsInterfaceIIDs.IID_IMarshal; + Entries.IMarshal.Vtable = IMarshalImpl.RoBufferVtable; + Entries.IAgileObject.IID = WellKnownWindowsInterfaceIIDs.IID_IAgileObject; + Entries.IAgileObject.Vtable = IAgileObjectImpl.Vtable; + Entries.IInspectable.IID = WellKnownWindowsInterfaceIIDs.IID_IInspectable; + Entries.IInspectable.Vtable = IInspectableImpl.Vtable; + Entries.IUnknown.IID = WellKnownWindowsInterfaceIIDs.IID_IUnknown; + Entries.IUnknown.Vtable = IUnknownImpl.Vtable; + } +} + +/// +/// A custom implementation for . +/// +file sealed unsafe class WindowsRuntimePinnedMemoryBufferComWrappersMarshallerAttribute : WindowsRuntimeComWrappersMarshallerAttribute +{ + /// + public override void* GetOrCreateComInterfaceForObject(object value) + { + // No reference tracking is needed, see notes in the marshaller attribute for 'WindowsRuntimePinnedArrayBuffer' + return (void*)WindowsRuntimeComWrappers.Default.GetOrCreateComInterfaceForObject(value, CreateComInterfaceFlags.None); + } + + /// + public override ComInterfaceEntry* ComputeVtables(out int count) + { + count = sizeof(WindowsRuntimePinnedMemoryBufferInterfaceEntries) / sizeof(ComInterfaceEntry); + + return (ComInterfaceEntry*)Unsafe.AsPointer(in WindowsRuntimePinnedMemoryBufferInterfaceEntriesImpl.Entries); + } +} + +/// +/// The native implementation of IBufferByteAccess for . +/// +file static unsafe class WindowsRuntimePinnedMemoryBufferByteAccessImpl +{ + /// + /// The value for the implementation. + /// + [FixedAddressValueType] + private static readonly IBufferByteAccessVftbl Vftbl; + + /// + /// Initializes . + /// + static WindowsRuntimePinnedMemoryBufferByteAccessImpl() + { + *(IInspectableVftbl*)Unsafe.AsPointer(ref Vftbl) = *(IInspectableVftbl*)IInspectableImpl.Vtable; + + Vftbl.Buffer = &Buffer; + } + + /// + /// Gets a pointer to the implementation. + /// + public static nint Vtable + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (nint)Unsafe.AsPointer(in Vftbl); + } + + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] + private static HRESULT Buffer(void* thisPtr, byte** value) + { + if (value is null) + { + return WellKnownErrorCodes.E_POINTER; + } + + try + { + var thisObject = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)thisPtr); + + *value = thisObject.Buffer(); + + return WellKnownErrorCodes.S_OK; + } + catch (Exception ex) + { + return RestrictedErrorInfoExceptionMarshaller.ConvertToUnmanaged(ex); + } + } +} diff --git a/src/WinRT.Runtime2/InteropServices/Buffers/WindowsRuntimeBufferHelpers.cs b/src/WinRT.Runtime2/InteropServices/Buffers/WindowsRuntimeBufferHelpers.cs index 35a1e503c..a259bca33 100644 --- a/src/WinRT.Runtime2/InteropServices/Buffers/WindowsRuntimeBufferHelpers.cs +++ b/src/WinRT.Runtime2/InteropServices/Buffers/WindowsRuntimeBufferHelpers.cs @@ -66,6 +66,14 @@ public static bool TryGetManagedSpanForCapacity(IBuffer buffer, out Span s return true; } + // Also handle pinned memory buffers (pointer-based, not array-backed) + if (buffer is WindowsRuntimePinnedMemoryBuffer pinnedMemoryBuffer) + { + span = pinnedMemoryBuffer.GetSpanForCapacity(); + + return true; + } + span = default; return false; @@ -102,6 +110,43 @@ public static bool TryGetManagedArray(IBuffer buffer, [NotNullWhen(true)] out by return false; } + /// + /// Tries to get a pointer to the underlying data for the specified buffer, only if it is a known managed buffer implementation. + /// + /// The input instance. + /// The underlying data, if retrieved. + /// Whether could be retrieved. + public static unsafe bool TryGetManagedData(IBuffer buffer, out byte* data) + { + // If the buffer is backed by a managed array, get the data pointer + if (buffer is WindowsRuntimeExternalArrayBuffer externalArrayBuffer) + { + data = externalArrayBuffer.Buffer(); + + return true; + } + + // Same as above for pinned arrays as well + if (buffer is WindowsRuntimePinnedArrayBuffer pinnedArrayBuffer) + { + data = pinnedArrayBuffer.Buffer(); + + return true; + } + + // Also handle pinned memory buffers (pointer-based, not array-backed) + if (buffer is WindowsRuntimePinnedMemoryBuffer pinnedMemoryBuffer) + { + data = pinnedMemoryBuffer.Buffer(); + + return true; + } + + data = null; + + return false; + } + /// /// Tries to get the underlying data for the specified buffer, only if backed by native memory. /// diff --git a/src/WinRT.Runtime2/InteropServices/Buffers/WindowsRuntimePinnedMemoryBuffer.cs b/src/WinRT.Runtime2/InteropServices/Buffers/WindowsRuntimePinnedMemoryBuffer.cs new file mode 100644 index 000000000..3b485243a --- /dev/null +++ b/src/WinRT.Runtime2/InteropServices/Buffers/WindowsRuntimePinnedMemoryBuffer.cs @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using Windows.Storage.Streams; + +namespace WindowsRuntime.InteropServices; + +/// +/// Provides a managed implementation of the interface backed by a pinned +/// pointer to memory. This buffer does not own the underlying memory and can be invalidated to +/// prevent further access once the memory it points to is no longer guaranteed to be pinned. +/// +[WindowsRuntimeManagedOnlyType] +internal sealed unsafe class WindowsRuntimePinnedMemoryBuffer : IBuffer +{ + /// + /// The pointer to the pinned memory. + /// + private volatile byte* _data; + + /// + /// The number of bytes that can be read or written in the buffer. + /// + private int _length; + + /// + /// The capacity of the buffer. + /// + private readonly int _capacity; + + /// + /// Creates a instance with the specified parameters. + /// + /// The pointer to the pinned memory. + /// The number of bytes. + /// The maximum number of bytes the buffer can hold. + /// This constructor doesn't validate any of its parameters. + public WindowsRuntimePinnedMemoryBuffer(byte* data, int length, int capacity) + { + Debug.Assert(data is not null); + Debug.Assert(length >= 0); + Debug.Assert(capacity >= 0); + Debug.Assert(capacity >= length); + + _data = data; + _length = length; + _capacity = capacity; + } + + /// + public uint Capacity + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (uint)_capacity; + } + + /// + public uint Length + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (uint)_length; + set + { + ArgumentOutOfRangeException.ThrowIfBufferLengthExceedsCapacity(value, Capacity); + + _length = unchecked((int)value); + } + } + + /// + /// Thrown if the buffer has been invalidated. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public byte* Buffer() + { + byte* data = _data; + + InvalidOperationException.ThrowIfBufferIsInvalidated(data); + + return data; + } + + /// + /// Thrown if the buffer has been invalidated. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Span GetSpanForCapacity() + { + byte* data = _data; + + InvalidOperationException.ThrowIfBufferIsInvalidated(data); + + return new(data, _capacity); + } + + /// + /// Invalidates the buffer, preventing any further access to the underlying memory. + /// + /// + /// + /// After calling this method, any attempt to call will throw + /// an . This is used to prevent use-after-free + /// scenarios when the buffer wraps memory that is only temporarily pinned (e.g. a span). + /// + /// + /// This type intentionally does not implement to perform this + /// invalidation. This is because would end up in the CCW interface + /// list for this implementation, which is not desirable since this type + /// is only meant to be used from the managed side in a specific, controlled context. + /// + /// + public void Invalidate() + { + _data = null; + } +} diff --git a/src/WinRT.Runtime2/InteropServices/Streams/Adapters/WindowsRuntimeManagedStreamAdapter.Implementation.Read.cs b/src/WinRT.Runtime2/InteropServices/Streams/Adapters/WindowsRuntimeManagedStreamAdapter.Implementation.Read.cs index 92e38ed29..f45d39ce7 100644 --- a/src/WinRT.Runtime2/InteropServices/Streams/Adapters/WindowsRuntimeManagedStreamAdapter.Implementation.Read.cs +++ b/src/WinRT.Runtime2/InteropServices/Streams/Adapters/WindowsRuntimeManagedStreamAdapter.Implementation.Read.cs @@ -2,7 +2,9 @@ // Licensed under the MIT License. using System; +using System.Buffers; using System.Diagnostics; +using System.Runtime.InteropServices; using System.Runtime.Versioning; using System.Threading; using System.Threading.Tasks; @@ -78,6 +80,11 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel // If already cancelled, stop early cancellationToken.ThrowIfCancellationRequested(); + if (count == 0) + { + return Task.FromResult(0); + } + // Helper to perform the actual asynchronous read operation async Task ReadCoreAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { @@ -124,6 +131,80 @@ async Task ReadCoreAsync(byte[] buffer, int offset, int count, Cancellation return ReadCoreAsync(buffer, offset, count, cancellationToken); } + /// + [SupportedOSPlatform("windows10.0.10240.0")] + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIfStreamIsDisposed(_windowsRuntimeStream); + NotSupportedException.ThrowIfStreamCannotRead(_canRead); + + // If already cancelled, stop early + cancellationToken.ThrowIfCancellationRequested(); + + if (buffer.IsEmpty) + { + return new(0); + } + + // Fast path: if the memory is backed by an array, use the existing array-based overload directly + if (MemoryMarshal.TryGetArray((ReadOnlyMemory)buffer, out ArraySegment segment)) + { + return new(ReadAsync(segment.Array!, segment.Offset, segment.Count, cancellationToken)); + } + + // Helper to perform the actual asynchronous read operation with pinned memory + async ValueTask ReadPinnedMemoryAsync(Memory buffer, CancellationToken cancellationToken) + { + using MemoryHandle handle = buffer.Pin(); + + WindowsRuntimePinnedMemoryBuffer pinnedMemoryBuffer; + + // An explicit unsafe block is needed here because the 'async unsafe' modifier is not supported + // by the language (CS4004: "Cannot await in an unsafe context"), so we scope the pointer access + // to just the buffer initialization, which is the only expression that requires unsafe context. + unsafe + { + pinnedMemoryBuffer = new((byte*)handle.Pointer, length: 0, capacity: buffer.Length); + } + + try + { + IInputStream windowsRuntimeStream = (IInputStream)EnsureNotDisposed(); + + IAsyncOperationWithProgress asyncReadOperation = windowsRuntimeStream.ReadAsync( + buffer: pinnedMemoryBuffer, + count: pinnedMemoryBuffer.Capacity, + options: InputStreamOptions.Partial); + + IBuffer? resultBuffer = await asyncReadOperation.AsTask(cancellationToken).ConfigureAwait(false); + + if (resultBuffer is null) + { + return 0; + } + + WindowsRuntimeIOHelpers.EnsureResultsInUserBuffer(pinnedMemoryBuffer, resultBuffer); + + Debug.Assert(resultBuffer.Length <= unchecked(int.MaxValue)); + + return unchecked((int)resultBuffer.Length); + } + catch (Exception exception) + { + WindowsRuntimeIOHelpers.GetExceptionDispatchInfo(exception).Throw(); + + return 0; + } + finally + { + pinnedMemoryBuffer.Invalidate(); + } + } + + // Slow path: pin the memory and use a pinned memory buffer for the async read operation + return ReadPinnedMemoryAsync(buffer, cancellationToken); + } + /// [SupportedOSPlatform("windows10.0.10240.0")] public override int Read(byte[] buffer, int offset, int count) @@ -135,6 +216,11 @@ public override int Read(byte[] buffer, int offset, int count) ObjectDisposedException.ThrowIfStreamIsDisposed(_windowsRuntimeStream); NotSupportedException.ThrowIfStreamCannotRead(_canRead); + if (count == 0) + { + return 0; + } + // Helper to do a sync-over-async read operation StreamReadAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state, bool usedByBlockingWrapper) { @@ -194,6 +280,55 @@ StreamReadAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallb } /// + [SupportedOSPlatform("windows10.0.10240.0")] + public override unsafe int Read(Span buffer) + { + ObjectDisposedException.ThrowIfStreamIsDisposed(_windowsRuntimeStream); + NotSupportedException.ThrowIfStreamCannotRead(_canRead); + + if (buffer.IsEmpty) + { + return 0; + } + + IInputStream windowsRuntimeStream = (IInputStream)EnsureNotDisposed(); + + // Pin the span so that it stays at the same address while the async I/O operation is in progress. + // We create a 'WindowsRuntimePinnedMemoryBuffer' wrapping the pinned pointer, then invalidate it + // in the 'finally' block to ensure no one can access the pointer after the span goes out of scope. + fixed (byte* pinnedData = buffer) + { + WindowsRuntimePinnedMemoryBuffer pinnedMemoryBuffer = new(pinnedData, length: 0, capacity: buffer.Length); + + try + { + IAsyncOperationWithProgress asyncReadOperation = windowsRuntimeStream.ReadAsync( + buffer: pinnedMemoryBuffer, + count: unchecked((uint)buffer.Length), + options: InputStreamOptions.Partial); + + // See the large comment in the 'Read(byte[], int, int)' method about why we use + // a custom 'IAsyncResult' implementation instead of 'ReadAsync' + 'AsTask' here. + StreamReadAsyncResult asyncResult = new( + asyncReadOperation, + pinnedMemoryBuffer, + userCompletionCallback: null, + userAsyncStateInfo: null, + processCompletedOperationInCallback: false); + + int numberOfBytesRead = EndRead(asyncResult); + + return numberOfBytesRead; + } + finally + { + pinnedMemoryBuffer.Invalidate(); + } + } + } + + /// + [SupportedOSPlatform("windows10.0.10240.0")] public override int ReadByte() { byte result = 0; diff --git a/src/WinRT.Runtime2/InteropServices/Streams/Adapters/WindowsRuntimeManagedStreamAdapter.Implementation.Write.cs b/src/WinRT.Runtime2/InteropServices/Streams/Adapters/WindowsRuntimeManagedStreamAdapter.Implementation.Write.cs index c52e7ee3d..aa5101c0c 100644 --- a/src/WinRT.Runtime2/InteropServices/Streams/Adapters/WindowsRuntimeManagedStreamAdapter.Implementation.Write.cs +++ b/src/WinRT.Runtime2/InteropServices/Streams/Adapters/WindowsRuntimeManagedStreamAdapter.Implementation.Write.cs @@ -2,6 +2,8 @@ // Licensed under the MIT License. using System; +using System.Buffers; +using System.Runtime.InteropServices; using System.Runtime.Versioning; using System.Threading; using System.Threading.Tasks; @@ -10,6 +12,8 @@ using Windows.Storage.Buffers; using Windows.Storage.Streams; +#pragma warning disable CS1573 + namespace WindowsRuntime.InteropServices; /// @@ -18,13 +22,6 @@ internal partial class WindowsRuntimeManagedStreamAdapter /// [SupportedOSPlatform("windows10.0.10240.0")] public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) - { - return BeginWrite(buffer, offset, count, callback, state, usedByBlockingWrapper: false); - } - - /// - [SupportedOSPlatform("windows10.0.10240.0")] - private StreamWriteAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state, bool usedByBlockingWrapper) { ArgumentNullException.ThrowIfNull(buffer); ArgumentOutOfRangeException.ThrowIfNegative(offset); @@ -33,20 +30,7 @@ private StreamWriteAsyncResult BeginWrite(byte[] buffer, int offset, int count, ObjectDisposedException.ThrowIfStreamIsDisposed(_windowsRuntimeStream); NotSupportedException.ThrowIfStreamCannotWrite(_canWrite); - IOutputStream windowsRuntimeStream = (IOutputStream)EnsureNotDisposed(); - - IBuffer asyncWriteBuffer = buffer.AsBuffer(offset, count); - - // See the large comment in the 'BeginRead' method about why we are not using the - // 'WriteAsync' method, and instead using a custom implementation of 'IAsyncResult'. - IAsyncOperationWithProgress asyncWriteOperation = windowsRuntimeStream.WriteAsync(asyncWriteBuffer); - - // See additional notes in the 'Read' method about how CCW objects for this result are managed - return new StreamWriteAsyncResult( - asyncWriteOperation, - callback, - state, - processCompletedOperationInCallback: !usedByBlockingWrapper); + return BeginWrite(buffer, offset, count, callback, state, usedByBlockingWrapper: false); } /// @@ -101,6 +85,11 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati // If already cancelled, stop early cancellationToken.ThrowIfCancellationRequested(); + if (count == 0) + { + return Task.CompletedTask; + } + IOutputStream windowsRuntimeStream = (IOutputStream)EnsureNotDisposed(); IBuffer asyncWriteBuffer = buffer.AsBuffer(offset, count); @@ -110,11 +99,71 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati return windowsRuntimeStream.WriteAsync(asyncWriteBuffer).AsTask(cancellationToken); } + /// + [SupportedOSPlatform("windows10.0.10240.0")] + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIfStreamIsDisposed(_windowsRuntimeStream); + NotSupportedException.ThrowIfStreamCannotWrite(_canWrite); + + // If already cancelled, stop early + cancellationToken.ThrowIfCancellationRequested(); + + if (buffer.IsEmpty) + { + return default; + } + + // Fast path: if the memory is backed by an array, use the existing array-based overload directly + if (MemoryMarshal.TryGetArray(buffer, out ArraySegment segment)) + { + return new(WriteAsync(segment.Array!, segment.Offset, segment.Count, cancellationToken)); + } + + // Helper to perform the actual asynchronous write operation with pinned memory + async ValueTask WritePinnedMemoryAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + using MemoryHandle handle = buffer.Pin(); + + WindowsRuntimePinnedMemoryBuffer pinnedMemoryBuffer; + + // See notes in 'ReadPinnedMemoryAsync' for why we need an 'unsafe' block here + unsafe + { + pinnedMemoryBuffer = new((byte*)handle.Pointer, length: buffer.Length, capacity: buffer.Length); + } + + try + { + IOutputStream windowsRuntimeStream = (IOutputStream)EnsureNotDisposed(); + + _ = await windowsRuntimeStream.WriteAsync(pinnedMemoryBuffer).AsTask(cancellationToken).ConfigureAwait(false); + } + finally + { + pinnedMemoryBuffer.Invalidate(); + } + } + + // Slow path: pin the memory and use a pinned memory buffer for the async write operation + return WritePinnedMemoryAsync(buffer, cancellationToken); + } + /// [SupportedOSPlatform("windows10.0.10240.0")] public override void Write(byte[] buffer, int offset, int count) { - // Arguments validation and disposal validation are done in 'BeginWrite' + ArgumentNullException.ThrowIfNull(buffer); + ArgumentOutOfRangeException.ThrowIfNegative(offset); + ArgumentOutOfRangeException.ThrowIfNegative(count); + ArgumentException.ThrowIfInsufficientArrayElementsAfterOffset(buffer.Length, offset, count); + ObjectDisposedException.ThrowIfStreamIsDisposed(_windowsRuntimeStream); + NotSupportedException.ThrowIfStreamCannotWrite(_canWrite); + + if (count == 0) + { + return; + } StreamWriteAsyncResult asyncResult = BeginWrite(buffer, offset, count, null, null, usedByBlockingWrapper: true); @@ -122,12 +171,55 @@ public override void Write(byte[] buffer, int offset, int count) } /// + [SupportedOSPlatform("windows10.0.10240.0")] public override void WriteByte(byte value) { // We don't need to call 'EnsureNotDisposed', see notes in 'ReadByte' Write(new ReadOnlySpan(in value)); } + /// + [SupportedOSPlatform("windows10.0.10240.0")] + public override unsafe void Write(ReadOnlySpan buffer) + { + ObjectDisposedException.ThrowIfStreamIsDisposed(_windowsRuntimeStream); + NotSupportedException.ThrowIfStreamCannotWrite(_canWrite); + + if (buffer.IsEmpty) + { + return; + } + + IOutputStream windowsRuntimeStream = (IOutputStream)EnsureNotDisposed(); + + // Pin the span so that it stays at the same address while the async I/O operation is in progress. + // We create a 'WindowsRuntimePinnedMemoryBuffer' wrapping the pinned pointer, then invalidate it + // in the 'finally' block to ensure no one can access the pointer after the span goes out of scope. + fixed (byte* pinnedData = buffer) + { + WindowsRuntimePinnedMemoryBuffer pinnedMemoryBuffer = new(pinnedData, length: buffer.Length, capacity: buffer.Length); + + try + { + // See the large comment in 'Read(byte[], int, int)' about why we use a custom 'IAsyncResult' + // implementation instead of 'WriteAsync' + 'AsTask' here (same deadlock concerns apply). + IAsyncOperationWithProgress asyncWriteOperation = windowsRuntimeStream.WriteAsync(pinnedMemoryBuffer); + + StreamWriteAsyncResult asyncResult = new( + asyncWriteOperation, + userCompletionCallback: null, + userAsyncStateInfo: null, + processCompletedOperationInCallback: false); + + EndWrite(asyncResult); + } + finally + { + pinnedMemoryBuffer.Invalidate(); + } + } + } + /// [SupportedOSPlatform("windows10.0.10240.0")] public override void Flush() @@ -188,4 +280,28 @@ public override Task FlushAsync(CancellationToken cancellationToken) return windowsRuntimeStream.FlushAsync().AsTask(cancellationToken); } + + /// + /// Indicates whether this method is being called by a method doing sync-over-async on the result. + [SupportedOSPlatform("windows10.0.10240.0")] + private StreamWriteAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state, bool usedByBlockingWrapper) + { + // This method doesn't do validation, to avoid repeating it in the 'Write' method that calls this one. + // It is only called by that method and by 'BeginWrite', so the validation there should be kept in sync. + + IOutputStream windowsRuntimeStream = (IOutputStream)EnsureNotDisposed(); + + IBuffer asyncWriteBuffer = buffer.AsBuffer(offset, count); + + // See the large comment in the 'BeginRead' method about why we are not using the + // 'WriteAsync' method, and instead using a custom implementation of 'IAsyncResult'. + IAsyncOperationWithProgress asyncWriteOperation = windowsRuntimeStream.WriteAsync(asyncWriteBuffer); + + // See additional notes in the 'Read' method about how CCW objects for this result are managed + return new StreamWriteAsyncResult( + asyncWriteOperation, + callback, + state, + processCompletedOperationInCallback: !usedByBlockingWrapper); + } } \ No newline at end of file diff --git a/src/WinRT.Runtime2/InteropServices/WindowsRuntimeBufferMarshal.cs b/src/WinRT.Runtime2/InteropServices/WindowsRuntimeBufferMarshal.cs index 83f3ccda5..ff0756aa5 100644 --- a/src/WinRT.Runtime2/InteropServices/WindowsRuntimeBufferMarshal.cs +++ b/src/WinRT.Runtime2/InteropServices/WindowsRuntimeBufferMarshal.cs @@ -39,19 +39,9 @@ public static unsafe bool TryGetDataUnsafe([NotNullWhen(true)] IBuffer? buffer, return true; } - // Also handle a managed instance of the external array buffer type from 'WinRT.Runtime.dll' - if (buffer is WindowsRuntimeExternalArrayBuffer externalArrayBuffer) - { - data = externalArrayBuffer.Buffer(); - - return true; - } - - // Same as above, but for pinned array buffers as well - if (buffer is WindowsRuntimePinnedArrayBuffer pinnedArrayBuffer) + // Also handle managed buffer implementations from 'WinRT.Runtime.dll' + if (WindowsRuntimeBufferHelpers.TryGetManagedData(buffer, out data)) { - data = pinnedArrayBuffer.Buffer(); - return true; } diff --git a/src/WinRT.Runtime2/Properties/WindowsRuntimeExceptionExtensions.cs b/src/WinRT.Runtime2/Properties/WindowsRuntimeExceptionExtensions.cs index aabe7a73b..7d84b588d 100644 --- a/src/WinRT.Runtime2/Properties/WindowsRuntimeExceptionExtensions.cs +++ b/src/WinRT.Runtime2/Properties/WindowsRuntimeExceptionExtensions.cs @@ -226,6 +226,26 @@ static void ThrowInvalidOperationException() } } + /// + /// Throws an if the buffer has been invalidated. + /// + /// The pointer to the buffer data. + /// Thrown if is . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [StackTraceHidden] + public static unsafe void ThrowIfBufferIsInvalidated(byte* data) + { + [DoesNotReturn] + [StackTraceHidden] + static void ThrowInvalidOperationException() + => throw new InvalidOperationException(WindowsRuntimeExceptionMessages.InvalidOperation_CannotAccessInvalidatedBuffer); + + if (data is null) + { + ThrowInvalidOperationException(); + } + } + /// /// Creates an indicating that the method cannot be called in the current state. /// diff --git a/src/WinRT.Runtime2/Properties/WindowsRuntimeExceptionMessages.cs b/src/WinRT.Runtime2/Properties/WindowsRuntimeExceptionMessages.cs index 14e1125b4..653c36dee 100644 --- a/src/WinRT.Runtime2/Properties/WindowsRuntimeExceptionMessages.cs +++ b/src/WinRT.Runtime2/Properties/WindowsRuntimeExceptionMessages.cs @@ -88,6 +88,8 @@ internal static class WindowsRuntimeExceptionMessages public const string ArgumentOutOfRange_IO_CannotSeekToNegativePosition = "Cannot seek to an absolute stream position that is negative."; + public const string InvalidOperation_CannotAccessInvalidatedBuffer = "Cannot access the underlying data of this buffer because it has been invalidated."; + public const string InvalidOperation_CannotCallThisMethodInCurrentState = "The state of this object does not permit invoking this method."; public const string InvalidOperation_CannotChangeBufferSizeOfStreamAdapter = "Cannot convert the specified Windows Runtime stream to a managed System.IO.Stream object with the specified buffer size because this Windows Runtime stream has been previously converted to a managed Stream object with a different buffer size. Ensure that the 'bufferSize' argument matches the existing buffer or use the '{0}'-overload without the 'bufferSize' argument to convert the specified Windows Runtime stream to a Stream object with the same buffer size as previously.";