Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update HTTPS server variable in ProxyHeaderModule #123

Merged
merged 1 commit into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,23 @@ namespace Microsoft.AspNetCore.SystemWebAdapters;

internal readonly struct ForwardedHost
{
private readonly int? _port;

public ForwardedHost(string host, string? proto)
{
var hostString = HostString.FromUriComponent(host);

IsSecure = string.Equals("https", proto, StringComparison.OrdinalIgnoreCase);
ServerName = hostString.Host;
Port = hostString.Port is int p ? p : GetDefaultPort(proto);
_port = hostString.Port;
}

private static int GetDefaultPort(string? proto)
=> string.Equals("https", proto, StringComparison.OrdinalIgnoreCase) ? 443 : 80;
private int DefaultPort => IsSecure ? 443 : 80;

public bool IsSecure { get; }

public string ServerName { get; }

public int Port { get; }
public int Port => _port is int port ? port : DefaultPort;
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,25 @@

namespace Microsoft.AspNetCore.SystemWebAdapters;

/// <summary>
/// Updates server and request variables based on proxy headers. See https://docs.microsoft.com/en-us/iis/web-dev-reference/server-variables for reference on what server variables should be used.
/// </summary>
internal class ProxyHeaderModule : IHttpModule
{
private const string Host = "Host";
private const string ServerHttps = "HTTPS";
private const string ServerName = "SERVER_NAME";
private const string ServerPort = "SERVER_PORT";
private const string ServerProtocol = "SERVER_PROTOCOL";
private const string ForwardedProto = "x-forwarded-proto";
private const string ForwardedHost = "x-forwarded-host";
private const string On = "ON";
private const string Off = "OFF";

private readonly ProxyOptions _options;
private readonly IOptions<ProxyOptions> _options;

public ProxyHeaderModule(IOptions<ProxyOptions> options)
{
_options = options?.Value ?? throw new ArgumentNullException(nameof(options));
_options = options ?? throw new ArgumentNullException(nameof(options));
}

public void Dispose()
Expand All @@ -31,7 +36,9 @@ public void Dispose()

public void Init(HttpApplication context)
{
if (_options.UseForwardedHeaders)
var options = _options.Value;

if (options.UseForwardedHeaders)
{
context.BeginRequest += (s, e) =>
{
Expand All @@ -41,15 +48,12 @@ public void Init(HttpApplication context)
}
else
{
if (_options.ServerName is null)
{
throw new InvalidOperationException("Server name must be set for proxy options.");
}
var values = new ServerValues(options);

context.BeginRequest += (s, e) =>
{
var request = ((HttpApplication)s).Context.Request;
UseOptions(request.Headers, request.ServerVariables);
UseOptions(values, request.Headers, request.ServerVariables);
};
}
}
Expand All @@ -64,33 +68,30 @@ public void UseHeaders(NameValueCollection requestHeaders, NameValueCollection s
{
if (requestHeaders[Host] is { } originalHost)
{
requestHeaders[_options.OriginalHostHeaderName] = originalHost;
requestHeaders[_options.Value.OriginalHostHeaderName] = originalHost;
}

var value = new ForwardedHost(host, proto);

serverVariables.Set(ServerName, value.ServerName);
serverVariables.Set(ServerPort, value.Port.ToString(CultureInfo.InvariantCulture));
serverVariables.Set(ServerHttps, value.IsSecure ? On : Off);

requestHeaders[Host] = host;
}

if (proto is { })
{
serverVariables.Set(ServerProtocol, proto);
}
}

private void UseOptions(NameValueCollection requestHeaders, NameValueCollection serverVariables)
private static void UseOptions(ServerValues values, NameValueCollection requestHeaders, NameValueCollection serverVariables)
{
UseForwardedFor(requestHeaders, serverVariables);

serverVariables.Set(ServerName, _options.ServerName);
serverVariables.Set(ServerPort, _options.ServerPortString);
serverVariables.Set(ServerProtocol, _options.Scheme);
requestHeaders[Host] = _options.ServerHostString;
serverVariables.Set(ServerName, values.Name);
serverVariables.Set(ServerPort, values.Port);
serverVariables.Set(ServerHttps, values.Https);
requestHeaders[Host] = values.Host;
}


private static void UseForwardedFor(NameValueCollection requestHeaders, NameValueCollection serverVariables)
{
if (requestHeaders["x-forwarded-for"] is { } remote)
Expand All @@ -99,4 +100,28 @@ private static void UseForwardedFor(NameValueCollection requestHeaders, NameValu
serverVariables.Set("REMOTE_HOST", remote);
}
}

private class ServerValues
{
public ServerValues(ProxyOptions options)
{
if (options.ServerName is null)
{
throw new InvalidOperationException("Server name must be set for proxy options.");
}

Name = options.ServerName;
Port = options.ServerPort.ToString(CultureInfo.InvariantCulture);
Https = string.Equals("https", options.Scheme, StringComparison.OrdinalIgnoreCase) ? On : Off;
Host = $"{Name}:{Port}";
}

public string Name { get; }

public string Port { get; }

public string Https { get; }

public string Host { get; }
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Globalization;

namespace Microsoft.AspNetCore.SystemWebAdapters;

public class ProxyOptions
{
private string? _port;
private string? _serverHostString;

/// <summary>
/// Gets or sets whether the X-Forwarded-* headers should be used for incoming requests.
/// </summary>
Expand All @@ -31,8 +26,4 @@ public class ProxyOptions
/// Gets or sets the scheme.
/// </summary>
public string Scheme { get; set; } = "https";

internal string ServerPortString => _port ??= ServerPort.ToString(CultureInfo.InvariantCulture);

internal string ServerHostString => _serverHostString ??= $"{ServerName}:{ServerPortString}";
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ public class ProxyHeaderModuleTests
private const string RemoteHost = "REMOTE_HOST";
private const string ServerName = "SERVER_NAME";
private const string ServerPort = "SERVER_PORT";
private const string ServerHttps = "HTTPS";
private const string On = "ON";
private const string Off = "OFF";

[Fact]
public void NoHeaderChange()
Expand Down Expand Up @@ -57,6 +60,7 @@ public void HostWithPortNoProto()
Assert.Null(serverVariables[RemoteHost]);
Assert.Equal("localhost:81", requestHeaders[Host]);
Assert.Null(requestHeaders[options.OriginalHostHeaderName]);
Assert.Equal(Off, serverVariables[ServerHttps]);
}

[Fact]
Expand All @@ -81,6 +85,7 @@ public void HostWithNoPortNoProto()
Assert.Null(serverVariables[RemoteHost]);
Assert.Equal("localhost", requestHeaders[Host]);
Assert.Null(requestHeaders[options.OriginalHostHeaderName]);
Assert.Equal(Off, serverVariables[ServerHttps]);
}

[Fact]
Expand All @@ -106,6 +111,7 @@ public void HostWithNoPortHttp()
Assert.Null(serverVariables[RemoteHost]);
Assert.Equal("localhost", requestHeaders[Host]);
Assert.Null(requestHeaders[options.OriginalHostHeaderName]);
Assert.Equal(Off, serverVariables[ServerHttps]);
}

[Fact]
Expand All @@ -132,6 +138,7 @@ public void HostAlreadySet()
Assert.Null(serverVariables[RemoteHost]);
Assert.Equal("localhost", requestHeaders[Host]);
Assert.Equal("localhost2:90", requestHeaders[options.OriginalHostHeaderName]);
Assert.Equal(Off, serverVariables[ServerHttps]);
}

[Fact]
Expand All @@ -157,6 +164,7 @@ public void HostWithNoPortHttps()
Assert.Null(serverVariables[RemoteHost]);
Assert.Equal("localhost", requestHeaders[Host]);
Assert.Null(requestHeaders[options.OriginalHostHeaderName]);
Assert.Equal(On, serverVariables[ServerHttps]);
}

[Fact]
Expand All @@ -182,6 +190,7 @@ public void IPv6NoPort()
Assert.Null(serverVariables[RemoteHost]);
Assert.Equal("::1", requestHeaders[Host]);
Assert.Null(requestHeaders[options.OriginalHostHeaderName]);
Assert.Equal(On, serverVariables[ServerHttps]);
}

[Fact]
Expand All @@ -207,6 +216,7 @@ public void IPv6WithPort()
Assert.Null(serverVariables[RemoteHost]);
Assert.Equal("[::1]:81", requestHeaders[Host]);
Assert.Null(requestHeaders[options.OriginalHostHeaderName]);
Assert.Equal(On, serverVariables[ServerHttps]);
}

[Fact]
Expand All @@ -230,6 +240,7 @@ public void ForwardedForSet()
Assert.Null(requestHeaders[Host]);
Assert.Null(serverVariables[ServerName]);
Assert.Null(serverVariables[ServerPort]);
Assert.Null(serverVariables[ServerHttps]);
Assert.Equal(ForwardedForValue, serverVariables[RemoteAddress]);
Assert.Equal(ForwardedForValue, serverVariables[RemoteHost]);
Assert.Null(requestHeaders[options.OriginalHostHeaderName]);
Expand Down