[MSL] Invalid MSL when casting 16-bit types #2046

Closed
gmitrano-unity opened this issue Oct 26, 2022 · 3 comments · Fixed by #2049
Labels: bug, in progress

Comments

@gmitrano-unity

I'm running into problems with the MSL that SPIRV-Cross generates when 16-bit integer arithmetic is combined with casts.

Simple Repro:

HLSL:

StructuredBuffer<uint16_t>  g_DataUShort;
RWStructuredBuffer<float16_t> g_DataHalf;

[numthreads(64, 1, 1)]
void Msl16Test(uint3 tid : SV_DispatchThreadID)
{
    uint idx = tid.x;
    g_DataHalf[idx] = asfloat16(g_DataUShort[idx] - g_DataUShort[idx + 1]);
}

SPIR-V (Compiled with dxc -spirv -enable-16bit-types -T cs_6_2 -E Msl16Test msl16test.hlsl):

; SPIR-V
; Version: 1.0
; Generator: Google spiregg; 0
; Bound: 34
; Schema: 0
               OpCapability Shader
               OpCapability StorageBuffer16BitAccess
               OpCapability Int16
               OpCapability Float16
               OpExtension "SPV_KHR_16bit_storage"
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %Msl16Test "Msl16Test" %gl_GlobalInvocationID
               OpExecutionMode %Msl16Test LocalSize 64 1 1
               OpSource HLSL 620
               OpName %type_StructuredBuffer_ushort "type.StructuredBuffer.ushort"
               OpName %g_DataUShort "g_DataUShort"
               OpName %type_RWStructuredBuffer_half "type.RWStructuredBuffer.half"
               OpName %g_DataHalf "g_DataHalf"
               OpName %Msl16Test "Msl16Test"
               OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
               OpDecorate %g_DataUShort DescriptorSet 0
               OpDecorate %g_DataUShort Binding 0
               OpDecorate %g_DataHalf DescriptorSet 0
               OpDecorate %g_DataHalf Binding 1
               OpDecorate %_runtimearr_ushort ArrayStride 2
               OpMemberDecorate %type_StructuredBuffer_ushort 0 Offset 0
               OpMemberDecorate %type_StructuredBuffer_ushort 0 NonWritable
               OpDecorate %type_StructuredBuffer_ushort BufferBlock
               OpDecorate %_runtimearr_half ArrayStride 2
               OpMemberDecorate %type_RWStructuredBuffer_half 0 Offset 0
               OpDecorate %type_RWStructuredBuffer_half BufferBlock
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
       %uint = OpTypeInt 32 0
     %uint_1 = OpConstant %uint 1
     %ushort = OpTypeInt 16 0
%_runtimearr_ushort = OpTypeRuntimeArray %ushort
%type_StructuredBuffer_ushort = OpTypeStruct %_runtimearr_ushort
%_ptr_Uniform_type_StructuredBuffer_ushort = OpTypePointer Uniform %type_StructuredBuffer_ushort
       %half = OpTypeFloat 16
%_runtimearr_half = OpTypeRuntimeArray %half
%type_RWStructuredBuffer_half = OpTypeStruct %_runtimearr_half
%_ptr_Uniform_type_RWStructuredBuffer_half = OpTypePointer Uniform %type_RWStructuredBuffer_half
     %v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
       %void = OpTypeVoid
         %20 = OpTypeFunction %void
%_ptr_Uniform_ushort = OpTypePointer Uniform %ushort
%_ptr_Uniform_half = OpTypePointer Uniform %half
%g_DataUShort = OpVariable %_ptr_Uniform_type_StructuredBuffer_ushort Uniform
 %g_DataHalf = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_half Uniform
%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
  %Msl16Test = OpFunction %void None %20
         %23 = OpLabel
         %24 = OpLoad %v3uint %gl_GlobalInvocationID
         %25 = OpCompositeExtract %uint %24 0
         %26 = OpAccessChain %_ptr_Uniform_ushort %g_DataUShort %int_0 %25
         %27 = OpLoad %ushort %26
         %28 = OpIAdd %uint %25 %uint_1
         %29 = OpAccessChain %_ptr_Uniform_ushort %g_DataUShort %int_0 %28
         %30 = OpLoad %ushort %29
         %31 = OpISub %ushort %27 %30
         %32 = OpBitcast %half %31
         %33 = OpAccessChain %_ptr_Uniform_half %g_DataHalf %int_0 %25
               OpStore %33 %32
               OpReturn
               OpFunctionEnd

MSL (Compiled with spirv-cross --msl msl16test.spv):

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

struct type_StructuredBuffer_ushort
{
    ushort _m0[1];
};

struct type_RWStructuredBuffer_half
{
    half _m0[1];
};

kernel void Msl16Test(const device type_StructuredBuffer_ushort& g_DataUint [[buffer(0)]], device type_RWStructuredBuffer_half& g_DataFloat [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
    g_DataFloat._m0[gl_GlobalInvocationID.x] = as_type<half>(g_DataUint._m0[gl_GlobalInvocationID.x] - g_DataUint._m0[gl_GlobalInvocationID.x + 1u]);
}

Metal Compilation (metal msl16test.metal):

msl16test.metal:18:48: error: as_type cast from 'int' to 'half' is not allowed
    g_DataFloat._m0[gl_GlobalInvocationID.x] = as_type<half>(g_DataUint._m0[gl_GlobalInvocationID.x] - g_DataUint._m0[gl_GlobalInvocationID.x + 1u]);
                                               ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 error generated.

Based on the error message, it looks like the ushort values from the SPIR-V are being promoted to full-width integers (int) during the subtraction in MSL. I believe the same happens for other operations, such as bit shifts.
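
The promotion can be reproduced on the host in plain C++, on the assumption that MSL applies the usual C++ integer promotion rules here (the Metal compiler's error message is consistent with that). This is only an illustrative sketch: promoted_sub and narrow_sub are made-up names, not SPIRV-Cross or Metal functions.

```cpp
#include <cstdint>
#include <type_traits>

// Difference of two 16-bit unsigned values, written the way the generated MSL
// writes it. Both operands are promoted to int before the subtraction, so the
// result type is int, not a 16-bit type.
constexpr auto promoted_sub(std::uint16_t a, std::uint16_t b) -> decltype(a - b)
{
    return a - b;
}

// Same subtraction with the result cast back to 16 bits, which matches what
// the SPIR-V asks for (OpISub with a 16-bit unsigned result type).
constexpr std::uint16_t narrow_sub(std::uint16_t a, std::uint16_t b)
{
    return static_cast<std::uint16_t>(a - b);
}

static_assert(std::is_same<decltype(promoted_sub(3, 4)), int>::value,
              "uint16_t - uint16_t yields int");
static_assert(promoted_sub(3, 4) == -1, "no wrap-around at 16 bits");
static_assert(narrow_sub(3, 4) == 65535, "wraps modulo 2^16 after the cast");
```

Because the uncast expression is an int, it no longer has the same size as a half, which is why a same-size reinterpretation such as as_type<half> is rejected.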

I hacked the code a bit to get it to insert an extra integer cast in spirv_glsl.cpp:6129 (CompilerGLSL::emit_binary_op_cast) like this:

void CompilerGLSL::emit_binary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
                                       const char *op, SPIRType::BaseType input_type, bool skip_cast_if_equal_type)
{
	string cast_op0, cast_op1;
	auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, skip_cast_if_equal_type);
	auto &out_type = get<SPIRType>(result_type);

	// We might have casted away from the result type, so bitcast again.
	// For example, arithmetic right shift with uint inputs.
	// Special case boolean outputs since relational opcodes output booleans instead of int/uint.
	string expr;
	if (out_type.basetype != input_type && out_type.basetype != SPIRType::Boolean)
	{
		expected_type.basetype = input_type;
		expr = bitcast_glsl_op(out_type, expected_type);
		expr += '(';
		expr += join(cast_op0, " ", op, " ", cast_op1);
		expr += ')';
	}
	// HACK: Add integer cast whenever the output type is narrow to avoid integer promotion issues
	else if (type_is_integral(out_type) && (out_type.width < 32))
	{
		expr = type_to_glsl(out_type);
		expr += '(';
		expr += join(cast_op0, " ", op, " ", cast_op1);
		expr += ')';
	}
	else
		expr += join(cast_op0, " ", op, " ", cast_op1);

	emit_op(result_type, result_id, expr, should_forward(op0) && should_forward(op1));
	inherit_expression_dependencies(result_id, op0);
	inherit_expression_dependencies(result_id, op1);
}

This gives me the following MSL which compiles successfully:

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

struct type_StructuredBuffer_ushort
{
    ushort _m0[1];
};

struct type_RWStructuredBuffer_half
{
    half _m0[1];
};

kernel void Msl16Test(const device type_StructuredBuffer_ushort& g_DataUint [[buffer(0)]], device type_RWStructuredBuffer_half& g_DataFloat [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
    g_DataFloat._m0[gl_GlobalInvocationID.x] = as_type<half>(ushort(g_DataUint._m0[gl_GlobalInvocationID.x] - g_DataUint._m0[gl_GlobalInvocationID.x + 1u]));
}

I created an issue rather than a PR since I know this change isn't robust.
Is there a reasonable way to inject some logic like this that only affects the MSL backend?

I see that there are already some references to this issue in the MSL-specific code (spirv_msl.cpp:15054):

string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
{
	if (out_type.basetype == in_type.basetype)
		return "";

	assert(out_type.basetype != SPIRType::Boolean);
	assert(in_type.basetype != SPIRType::Boolean);

	bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type) && (out_type.vecsize == in_type.vecsize);
	bool same_size_cast = (out_type.width * out_type.vecsize) == (in_type.width * in_type.vecsize);

	// Bitcasting can only be used between types of the same overall size.
	// And always formally cast between integers, because it's trivial, and also
	// because Metal can internally cast the results of some integer ops to a larger
	// size (eg. short shift right becomes int), which means chaining integer ops
	// together may introduce size variations that SPIR-V doesn't know about.
	if (same_size_cast && !integral_cast)
		return "as_type<" + type_to_glsl(out_type) + ">";
	else
		return type_to_glsl(out_type);
}
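
To make the decision in that function easier to follow, here is a stripped-down model of it in standalone C++. TypeModel, CastKind and choose_cast are hypothetical names for illustration only; the model leaves out the early return for identical base types and the Boolean assertions, and it returns an enum instead of building a string.

```cpp
// Hypothetical stand-in for SPIRType: only the fields the decision reads.
struct TypeModel
{
    bool integral;    // true for integer base types
    unsigned width;   // bits per component, as declared in SPIR-V
    unsigned vecsize; // number of components
};

enum class CastKind
{
    AsType,   // emit as_type<T>(expr), a bit reinterpretation
    ValueCast // emit T(expr), an ordinary conversion
};

// Same-size, non-integer-to-integer casts become as_type; everything else
// becomes a value cast.
constexpr CastKind choose_cast(TypeModel out, TypeModel in)
{
    return ((out.width * out.vecsize == in.width * in.vecsize) &&
            !(out.integral && in.integral && out.vecsize == in.vecsize))
               ? CastKind::AsType
               : CastKind::ValueCast;
}

constexpr TypeModel kUShort{true, 16, 1};
constexpr TypeModel kShort{true, 16, 1};
constexpr TypeModel kHalf{false, 16, 1};

static_assert(choose_cast(kHalf, kUShort) == CastKind::AsType,
              "ushort -> half is a same-size reinterpretation");
static_assert(choose_cast(kShort, kUShort) == CastKind::ValueCast,
              "integer -> integer is always a formal cast");
```

The sizes compared are the ones declared in SPIR-V. If the emitted operand expression has already been widened to int by the target language, the as_type branch is still taken, which matches the failure in the repro.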

What would be the best way to solve this? I imagine this will require some new hooks in CompilerMSL? Or maybe there's another way?

Thanks!

HansKristian-Work added the needs triage label on Oct 27, 2022
@HansKristian-Work
Contributor

This is annoying. I think the only correct way to do this is to cast after an arithmetic operation that has implicit promotion. For example, if we consider (ushort(3) - ushort(4)) >> ushort(3), we expect the value to stay ushort, but we might end up seeing:

ushort(3) - ushort(4) implicitly promotes to int, making it int(-1). The shift will then be treated as arithmetic rather than logical unless we explicitly cast the result.
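
The difference between the two readings can be checked with a small host-side C++ sketch, assuming MSL promotes the same way C++ does. shift_promoted and shift_narrowed are hypothetical helper names. Right-shifting a negative int is implementation-defined before C++20, although mainstream compilers shift arithmetically.

```cpp
#include <cstdint>

// (a - b) >> s with no cast in between: the difference is computed as int,
// so a negative difference is shifted as a signed value.
constexpr int shift_promoted(std::uint16_t a, std::uint16_t b, std::uint16_t s)
{
    return (a - b) >> s;
}

// Same expression with the difference cast back to 16 bits before the shift,
// so the shifted value is the non-negative 16-bit wrap-around result.
constexpr std::uint16_t shift_narrowed(std::uint16_t a, std::uint16_t b, std::uint16_t s)
{
    return static_cast<std::uint16_t>(static_cast<std::uint16_t>(a - b) >> s);
}

static_assert(shift_promoted(3, 4, 3) == -1, "int(-1) >> 3 stays -1 (arithmetic shift)");
static_assert(shift_narrowed(3, 4, 3) == 8191, "0xFFFF >> 3 == 0x1FFF (logical shift)");
```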

@gmitrano-unity
Author

I did some more testing and found that the integer promotion behavior doesn't seem to affect vector types.

As in, this gets promoted:

g_DataHalf._m0[gl_GlobalInvocationID.x] = as_type<half>(g_DataUShort._m0[gl_GlobalInvocationID.x] - g_DataUShort._m0[gl_GlobalInvocationID.x + 1u]);

but this does not:

g_DataHalf._m0[gl_GlobalInvocationID.x] = as_type<half2>(ushort2(g_DataUShort._m0[gl_GlobalInvocationID.x], g_DataUShort._m0[gl_GlobalInvocationID.x + 1]) - ushort2(g_DataUShort._m0[gl_GlobalInvocationID.x + 2], g_DataUShort._m0[gl_GlobalInvocationID.x + 3]));

Since emit_binary_op_cast runs after every arithmetic operation already, it seems like we could just add a snippet like this inside it:

// If integer promotions are implicit, then we need to explicitly cast back to the original type after any binary
// operation in order to prevent the types from getting out of sync.
if (is_integer_promotion_implicit() && type_is_integral(out_type) && (out_type.width < 32) && (out_type.vecsize == 1))
{
    expr = join(type_to_glsl(out_type), "(", expr, ")");
}

Would that be enough to fix the issue? (This assumes is_integer_promotion_implicit() only returns true for MSL, and that some fixup is done inside optimize_read_modify_write() so that the shader regression tests still pass.)

Would that be an acceptable approach?
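
For reference, the condition in the proposed snippet can be written as a standalone predicate to see which result types it would wrap. This is a hypothetical model, not SPIRV-Cross code: needs_narrowing_cast and its parameters are invented here, and backend_promotes stands in for the proposed is_integer_promotion_implicit() hook.

```cpp
// Returns true when the result of a binary op would be wrapped in an explicit
// cast back to its SPIR-V type under the proposed condition.
constexpr bool needs_narrowing_cast(bool backend_promotes, bool integral,
                                    unsigned width, unsigned vecsize)
{
    return backend_promotes && integral && width < 32 && vecsize == 1;
}

static_assert(needs_narrowing_cast(true, true, 16, 1), "scalar ushort is wrapped");
static_assert(!needs_narrowing_cast(true, true, 16, 2), "ushort2 is left alone");
static_assert(!needs_narrowing_cast(true, true, 32, 1), "uint is left alone");
static_assert(!needs_narrowing_cast(true, false, 16, 1), "half is left alone");
static_assert(!needs_narrowing_cast(false, true, 16, 1), "other backends are unaffected");
```

The vecsize == 1 term encodes the observation that vector operands did not appear to be promoted in testing.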

HansKristian-Work added the bug and in progress labels and removed the needs triage label on Oct 31, 2022
@gmitrano-unity
Author

Thank you for fixing this! I can confirm that #2049 fixes my local shaders as well. 🙂

stuartcarnie added a commit to stuartcarnie/godot that referenced this issue Dec 20, 2024
tGautot pushed a commit to tGautot/godot that referenced this issue Feb 5, 2025
tGautot pushed a commit to tGautot/godot that referenced this issue Feb 5, 2025
WhalesState pushed a commit to WhalesState/blazium that referenced this issue Mar 22, 2025
WhalesState pushed a commit to WhalesState/blazium that referenced this issue Mar 23, 2025
WhalesState pushed a commit to WhalesState/blazium that referenced this issue Mar 24, 2025
WhalesState pushed a commit to WhalesState/blazium that referenced this issue Mar 24, 2025