// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Buffers;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.WebUtilities;

namespace Microsoft.AspNetCore.SystemWebAdapters.Features;

/// <summary>
/// Response body feature that layers System.Web-style behaviors (optional output buffering, output filters,
/// content suppression, clearing content and ending the response) on top of a wrapped <see cref="IHttpResponseBodyFeature"/>.
/// </summary>
/// <remarks>
/// The feature is itself the response body <see cref="Stream"/>. Writes are routed to <see cref="CurrentStream"/>, which is
/// either a lazily created <see cref="FileBufferingWriteStream"/> (while buffering) or the wrapped feature's stream.
/// </remarks>
internal class HttpResponseAdapterFeature : Stream, IHttpResponseBodyFeature, IHttpResponseBufferingFeature, IHttpResponseEndFeature, IHttpResponseContentFeature
{
    /// <summary>
    /// Tracks where response output is currently being routed.
    /// </summary>
    private enum StreamState
    {
        // No buffering decision has been made and nothing has been written or started yet.
        NotStarted,

        // Output is captured in a FileBufferingWriteStream until it is drained.
        Buffering,

        // Output goes directly to the wrapped response body.
        NotBuffering,

        // CompleteAsync has finished flushing; the response is ended.
        Complete,
    }

    private readonly IHttpResponseBodyFeature _responseBodyFeature;

    private FileBufferingWriteStream? _bufferedStream;
    private PipeWriter? _pipeWriter;
    private StreamState _state;
    private Func<FileBufferingWriteStream>? _factory;
    private bool _suppressContent;
    private Stream? _filter;

    /// <summary>
    /// Creates the adapter around the server's response body feature.
    /// </summary>
    /// <param name="httpResponseBody">The underlying response body feature that output is eventually written to.</param>
    public HttpResponseAdapterFeature(IHttpResponseBodyFeature httpResponseBody)
    {
        _responseBodyFeature = httpResponseBody;
        _state = StreamState.NotStarted;
    }

    Task IHttpResponseBodyFeature.CompleteAsync() => CompleteAsync();

    /// <summary>
    /// Disables buffering on both the wrapped feature and this adapter. Subsequent stream writes go directly to the
    /// wrapped response body.
    /// </summary>
    public void DisableBuffering()
    {
        _responseBodyFeature.DisableBuffering();

        // A completed response must stay completed: regressing the state would make IsEnded report false and
        // allow CompleteAsync to run (and complete the underlying writers) a second time. Mirrors CurrentStream.
        if (_state != StreamState.Complete)
        {
            _state = StreamState.NotBuffering;
        }

        // If anything is already buffered, we'll use a custom pipe that will
        // clear out the buffer the next time flush is called since this method
        // is not async
        if (_bufferedStream is { })
        {
            _pipeWriter = new FlushingBufferedPipeWriter(this, _responseBodyFeature.Writer);
        }
        else
        {
            _pipeWriter = _responseBodyFeature.Writer;
        }
    }

    /// <summary>
    /// Turns on output buffering. Only allowed before any output has been written or the response has started.
    /// </summary>
    /// <param name="memoryThreshold">In-memory threshold for the buffer; defaults to <see cref="PreBufferRequestStreamAttribute.DefaultBufferThreshold"/>.</param>
    /// <param name="bufferLimit">Optional maximum buffer size, passed through to <see cref="FileBufferingWriteStream"/>.</param>
    /// <exception cref="InvalidOperationException">Buffering was not enabled and the state is no longer <see cref="StreamState.NotStarted"/>.</exception>
    void IHttpResponseBufferingFeature.EnableBuffering(int? memoryThreshold, long? bufferLimit)
    {
        if (_state == StreamState.Buffering)
        {
            return;
        }
        else if (_state == StreamState.NotStarted)
        {
            Debug.Assert(_bufferedStream is null);

            _state = StreamState.Buffering;

            // The buffer itself is created lazily the first time CurrentStream is read.
            _factory = () => new FileBufferingWriteStream(memoryThreshold ?? PreBufferRequestStreamAttribute.DefaultBufferThreshold, bufferLimit);
        }
        else
        {
            throw new InvalidOperationException("Cannot enable buffering if writing has begun");
        }
    }

    /// <summary>
    /// Starts the response. If no buffering decision was made yet, this locks the adapter into non-buffered mode.
    /// </summary>
    Task IHttpResponseBodyFeature.StartAsync(CancellationToken cancellationToken)
    {
        if (_state == StreamState.NotStarted)
        {
            _state = StreamState.NotBuffering;
        }

        return _responseBodyFeature.StartAsync(cancellationToken);
    }

    /// <summary>
    /// True when the state is <see cref="StreamState.Buffering"/> or <see cref="StreamState.Complete"/>.
    /// </summary>
    bool IHttpResponseBufferingFeature.IsEnabled
    {
        get
        {
            return _state != StreamState.NotBuffering && _state != StreamState.NotStarted;
        }
    }

    /// <summary>
    /// Flushes the cached pipe writer, if any, and, while buffering, drains the buffered output downstream.
    /// </summary>
    private async ValueTask FlushInternalAsync()
    {
        if (_pipeWriter is { })
        {
            await _pipeWriter.FlushAsync();
        }

        if (_state is StreamState.Buffering)
        {
            await DrainStreamAsync(default);
        }
    }

    /// <summary>
    /// Writes buffered output downstream and releases the buffer.
    /// </summary>
    /// <remarks>
    /// When a filter is set, output is drained into the filter, which is then disposed and cleared. Otherwise it is
    /// drained into the wrapped response stream. When <see cref="SuppressContent"/> is set, the buffered output is
    /// discarded without being written. In all cases the buffer is disposed and the field cleared.
    /// </remarks>
    private async ValueTask DrainStreamAsync(CancellationToken token)
    {
        if (_bufferedStream is null)
        {
            return;
        }

        if (!SuppressContent)
        {
            if (_filter is { } filter)
            {
                await _bufferedStream.DrainBufferAsync(filter, token);
                await filter.DisposeAsync();
                _filter = null;
            }
            else
            {
                await _bufferedStream.DrainBufferAsync(_responseBodyFeature.Stream, token);
            }
        }

        await _bufferedStream.DisposeAsync();
        _bufferedStream = null;
    }

    Stream IHttpResponseBodyFeature.Stream => this;

    /// <summary>
    /// Lazily creates a <see cref="PipeWriter"/> over this stream (leaving this stream open). If the response has
    /// already completed, the new writer is completed immediately.
    /// </summary>
    PipeWriter IHttpResponseBodyFeature.Writer
    {
        get
        {
            if (_pipeWriter is null)
            {
                _pipeWriter = PipeWriter.Create(this, new StreamPipeWriterOptions(leaveOpen: true));

                if (_state is StreamState.Complete)
                {
                    _pipeWriter.Complete();
                }
            }

            return _pipeWriter;
        }
    }

    /// <summary>
    /// When true, buffered output is discarded instead of being written downstream. Setting it to true requires
    /// buffering to be enabled.
    /// </summary>
    /// <exception cref="InvalidOperationException">Set to true while not buffering.</exception>
    public bool SuppressContent
    {
        get => _suppressContent;
        set
        {
            if (value)
            {
                VerifyBuffering();
            }

            _suppressContent = value;
        }
    }

    Task IHttpResponseEndFeature.EndAsync() => CompleteAsync();

    bool IHttpResponseEndFeature.IsEnded => _state == StreamState.Complete;

    /// <summary>
    /// Clears output written so far.
    /// </summary>
    /// <remarks>
    /// If the current stream is seekable it is truncated. Otherwise buffering is required and any buffered output is
    /// discarded. Note that reading <see cref="CurrentStream"/> here has its usual side effects (it may create the
    /// buffer or move a not-yet-started response to non-buffered mode).
    /// </remarks>
    /// <exception cref="InvalidOperationException">The current stream is not seekable and buffering is not enabled.</exception>
    void IHttpResponseContentFeature.ClearContent()
    {
        if (CurrentStream is { CanSeek: true } body)
        {
            body.SetLength(0);
            return;
        }

        VerifyBuffering();

        _bufferedStream?.Dispose();
        _bufferedStream = null;
    }

    /// <summary>
    /// Throws unless the adapter is currently buffering; on success <see cref="_factory"/> is known to be set.
    /// </summary>
    /// <exception cref="InvalidOperationException">The adapter is not in <see cref="StreamState.Buffering"/>.</exception>
    [MemberNotNull(nameof(_factory))]
    private void VerifyBuffering()
    {
        if (_state != StreamState.Buffering)
        {
            throw new InvalidOperationException("Response buffering is required");
        }

        Debug.Assert(_factory is not null);
    }

    ValueTask IHttpResponseBufferingFeature.FlushAsync() => FlushInternalAsync();

    /// <summary>
    /// The stream that writes are currently routed to.
    /// </summary>
    /// <remarks>
    /// While buffering this is the lazily created buffer. Otherwise it is the wrapped response stream, and reading this
    /// property moves any non-completed state to <see cref="StreamState.NotBuffering"/>.
    /// </remarks>
    private Stream CurrentStream
    {
        get
        {
            if (_state == StreamState.Buffering)
            {
                VerifyBuffering();
                return _bufferedStream ??= _factory();
            }
            else
            {
                if (_state != StreamState.Complete)
                {
                    _state = StreamState.NotBuffering;
                }

                return _responseBodyFeature.Stream;
            }
        }
    }

    /// <summary>
    /// Asynchronously releases the output buffer (if any) and then the stream itself.
    /// </summary>
    public override async ValueTask DisposeAsync()
    {
        if (_bufferedStream is { } buffered)
        {
            // Clear the field first so the synchronous Dispose(bool) reached via base.DisposeAsync does not dispose it again.
            _bufferedStream = null;
            await buffered.DisposeAsync();
        }

        await base.DisposeAsync();
    }

    /// <summary>
    /// Releases the output buffer on synchronous disposal; previously only <see cref="DisposeAsync"/> released it.
    /// </summary>
    /// <param name="disposing">True when called from <see cref="Stream.Dispose()"/> rather than a finalizer.</param>
    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            _bufferedStream?.Dispose();
            _bufferedStream = null;
        }

        base.Dispose(disposing);
    }

    public override bool CanRead => false;

    public override bool CanSeek => false;

    public override bool CanWrite => true;

    public override long Length => CurrentStream.Length;

    public override long Position
    {
        get => throw new NotSupportedException();
        set => throw new NotSupportedException();
    }

    /// <summary>
    /// Stream that buffered output is drained through. Reading returns the configured filter, or the wrapped response
    /// stream when none is set. Both getting and setting require buffering.
    /// </summary>
    /// <exception cref="InvalidOperationException">Buffering is not enabled.</exception>
    [AllowNull]
    Stream IHttpResponseBufferingFeature.Filter
    {
        get
        {
            VerifyBuffering();

            return _filter ?? _responseBodyFeature.Stream;
        }
        set
        {
            VerifyBuffering();

            _filter = value;
        }
    }

    /// <summary>
    /// Ends the response. It flushes outstanding output, marks the adapter complete, and completes the cached pipe
    /// writer and the wrapped feature. It is a no-op if the response has already completed.
    /// </summary>
    private async Task CompleteAsync()
    {
        if (_state == StreamState.Complete)
        {
            return;
        }

        await FlushInternalAsync();

        _state = StreamState.Complete;

        if (_pipeWriter is { })
        {
            await _pipeWriter.CompleteAsync();
        }

        await _responseBodyFeature.CompleteAsync();
    }

    public override void Flush() => CurrentStream.Flush();

    public override Task FlushAsync(CancellationToken cancellationToken) => CurrentStream.FlushAsync(cancellationToken);

    public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();

    public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();

    /// <summary>
    /// Sends a file by copying it into <see cref="CurrentStream"/> via <see cref="SendFileFallback"/>, so buffering and
    /// filters still apply.
    /// </summary>
    public Task SendFileAsync(string path, long offset, long? count, CancellationToken cancellationToken = default)
        => SendFileFallback.SendFileAsync(CurrentStream, path, offset, count, cancellationToken);

    public override void SetLength(long value) => throw new NotSupportedException();

    public override void Write(byte[] buffer, int offset, int count) => CurrentStream.Write(buffer, offset, count);

    public override void Write(ReadOnlySpan<byte> buffer) => CurrentStream.Write(buffer);

    public override void WriteByte(byte value) => CurrentStream.WriteByte(value);

    public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
        => CurrentStream.WriteAsync(buffer, cancellationToken);

    public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        => CurrentStream.WriteAsync(buffer, offset, count, cancellationToken);

    /// <summary>
    /// A <see cref="PipeWriter"/> that can flush any existing buffered items before writing next sequence of bytes
    /// Intended to be used if <see cref="IHttpResponseBodyFeature.DisableBuffering"/> is called and data has been buffered
    /// to ensure that the final output will be ordered correctly (since we can't asynchronously write the data in that call).
    /// </summary>
    /// <remarks>
    /// Calls to <see cref="Advance(int)"/>, <see cref="GetSpan(int)"/>, <see cref="GetMemory(int)"/> must be called
    /// in a group without calling <see cref="FlushAsync(CancellationToken)"/>. If not, then the call to <see cref="Advance(int)"/>
    /// will potentially advance the inner pipe rather than the buffer.
    /// </remarks>
    private sealed class FlushingBufferedPipeWriter : PipeWriter
    {
        private readonly PipeWriter _other;

        // Non-null until the feature's buffered output has been drained on the first flush.
        private HttpResponseAdapterFeature? _feature;

        // Holds bytes written through this writer before the feature's buffer has been drained.
        private ArrayBufferWriter<byte>? _buffer;

        public FlushingBufferedPipeWriter(HttpResponseAdapterFeature feature, PipeWriter other)
        {
            _feature = feature;
            _other = other;
        }

        public override void CancelPendingFlush() => _other.CancelPendingFlush();

        public override void Complete(Exception? exception = null) => _other.Complete(exception);

        /// <summary>
        /// Drains the feature's buffered output and then this writer's local buffer into the inner writer, and
        /// flushes the inner writer.
        /// </summary>
        public override async ValueTask<FlushResult> FlushAsync(CancellationToken cancellationToken = default)
        {
            await FlushExistingDataAsync(cancellationToken);

            return await _other.FlushAsync(cancellationToken);
        }

        // Writes the previously buffered response first, then anything written here since, preserving output order.
        private async ValueTask FlushExistingDataAsync(CancellationToken cancellationToken)
        {
            if (_feature is { })
            {
                await _feature.DrainStreamAsync(cancellationToken);
                _feature = null;
            }

            if (_buffer is { })
            {
                await _other.WriteAsync(_buffer.WrittenMemory, cancellationToken);
                _buffer = null;
            }
        }

        /// <summary>
        /// True until the feature's buffered output has been drained by a flush.
        /// </summary>
        public bool IsBuffered => _feature is { };

        public override void Advance(int bytes)
        {
            if (_buffer is { })
            {
                _buffer.Advance(bytes);
            }
            else
            {
                _other.Advance(bytes);
            }
        }

        public override Memory<byte> GetMemory(int sizeHint = 0)
        {
            if (IsBuffered)
            {
                return (_buffer ??= new()).GetMemory(sizeHint);
            }
            else
            {
                return _other.GetMemory(sizeHint);
            }
        }

        public override Span<byte> GetSpan(int sizeHint = 0)
        {
            if (IsBuffered)
            {
                return (_buffer ??= new()).GetSpan(sizeHint);
            }
            else
            {
                return _other.GetSpan(sizeHint);
            }
        }
    }
}