Fix vote and shuffle shader instructions on AMD GPUs (#5540)
* Move shuffle handling out of the backend to a transform pass * Handle subgroup sizes higher than 32 * Stop using the subgroup size control extension * Make GenerateShuffleFunction static * Shader cache version bump
This commit is contained in:
parent
64079c034c
commit
6ed613a6e6
35 changed files with 445 additions and 265 deletions
|
@ -25,6 +25,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
|
|||
{
|
||||
context.AppendLine("#extension GL_KHR_shader_subgroup_basic : enable");
|
||||
context.AppendLine("#extension GL_KHR_shader_subgroup_ballot : enable");
|
||||
context.AppendLine("#extension GL_KHR_shader_subgroup_shuffle : enable");
|
||||
}
|
||||
|
||||
context.AppendLine("#extension GL_ARB_shader_group_vote : enable");
|
||||
|
@ -201,26 +202,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
|
|||
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/MultiplyHighU32.glsl");
|
||||
}
|
||||
|
||||
if ((info.HelperFunctionsMask & HelperFunctionsMask.Shuffle) != 0)
|
||||
{
|
||||
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/Shuffle.glsl");
|
||||
}
|
||||
|
||||
if ((info.HelperFunctionsMask & HelperFunctionsMask.ShuffleDown) != 0)
|
||||
{
|
||||
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleDown.glsl");
|
||||
}
|
||||
|
||||
if ((info.HelperFunctionsMask & HelperFunctionsMask.ShuffleUp) != 0)
|
||||
{
|
||||
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleUp.glsl");
|
||||
}
|
||||
|
||||
if ((info.HelperFunctionsMask & HelperFunctionsMask.ShuffleXor) != 0)
|
||||
{
|
||||
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleXor.glsl");
|
||||
}
|
||||
|
||||
if ((info.HelperFunctionsMask & HelperFunctionsMask.SwizzleAdd) != 0)
|
||||
{
|
||||
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/SwizzleAdd.glsl");
|
||||
|
|
|
@ -5,10 +5,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
|
|||
public static string MultiplyHighS32 = "Helper_MultiplyHighS32";
|
||||
public static string MultiplyHighU32 = "Helper_MultiplyHighU32";
|
||||
|
||||
public static string Shuffle = "Helper_Shuffle";
|
||||
public static string ShuffleDown = "Helper_ShuffleDown";
|
||||
public static string ShuffleUp = "Helper_ShuffleUp";
|
||||
public static string ShuffleXor = "Helper_ShuffleXor";
|
||||
public static string SwizzleAdd = "Helper_SwizzleAdd";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
float Helper_Shuffle(float x, uint index, uint mask, out bool valid)
|
||||
{
|
||||
uint clamp = mask & 0x1fu;
|
||||
uint segMask = (mask >> 8) & 0x1fu;
|
||||
uint minThreadId = $SUBGROUP_INVOCATION$ & segMask;
|
||||
uint maxThreadId = minThreadId | (clamp & ~segMask);
|
||||
uint srcThreadId = (index & ~segMask) | minThreadId;
|
||||
valid = srcThreadId <= maxThreadId;
|
||||
float v = $SUBGROUP_BROADCAST$(x, srcThreadId);
|
||||
return valid ? v : x;
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
float Helper_ShuffleDown(float x, uint index, uint mask, out bool valid)
|
||||
{
|
||||
uint clamp = mask & 0x1fu;
|
||||
uint segMask = (mask >> 8) & 0x1fu;
|
||||
uint minThreadId = $SUBGROUP_INVOCATION$ & segMask;
|
||||
uint maxThreadId = minThreadId | (clamp & ~segMask);
|
||||
uint srcThreadId = $SUBGROUP_INVOCATION$ + index;
|
||||
valid = srcThreadId <= maxThreadId;
|
||||
float v = $SUBGROUP_BROADCAST$(x, srcThreadId);
|
||||
return valid ? v : x;
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
float Helper_ShuffleUp(float x, uint index, uint mask, out bool valid)
|
||||
{
|
||||
uint segMask = (mask >> 8) & 0x1fu;
|
||||
uint minThreadId = $SUBGROUP_INVOCATION$ & segMask;
|
||||
uint srcThreadId = $SUBGROUP_INVOCATION$ - index;
|
||||
valid = int(srcThreadId) >= int(minThreadId);
|
||||
float v = $SUBGROUP_BROADCAST$(x, srcThreadId);
|
||||
return valid ? v : x;
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
float Helper_ShuffleXor(float x, uint index, uint mask, out bool valid)
|
||||
{
|
||||
uint clamp = mask & 0x1fu;
|
||||
uint segMask = (mask >> 8) & 0x1fu;
|
||||
uint minThreadId = $SUBGROUP_INVOCATION$ & segMask;
|
||||
uint maxThreadId = minThreadId | (clamp & ~segMask);
|
||||
uint srcThreadId = $SUBGROUP_INVOCATION$ ^ index;
|
||||
valid = srcThreadId <= maxThreadId;
|
||||
float v = $SUBGROUP_BROADCAST$(x, srcThreadId);
|
||||
return valid ? v : x;
|
||||
}
|
|
@ -9,6 +9,7 @@ using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenFSI;
|
|||
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenMemory;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenPacking;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenShuffle;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenVector;
|
||||
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
|
||||
|
||||
|
@ -174,6 +175,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
|
|||
case Instruction.PackHalf2x16:
|
||||
return PackHalf2x16(context, operation);
|
||||
|
||||
case Instruction.Shuffle:
|
||||
return Shuffle(context, operation);
|
||||
|
||||
case Instruction.Store:
|
||||
return Store(context, operation);
|
||||
|
||||
|
|
|
@ -13,14 +13,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
|
|||
AggregateType dstType = GetSrcVarType(operation.Inst, 0);
|
||||
|
||||
string arg = GetSoureExpr(context, operation.GetSource(0), dstType);
|
||||
char component = "xyzw"[operation.Index];
|
||||
|
||||
if (context.HostCapabilities.SupportsShaderBallot)
|
||||
{
|
||||
return $"unpackUint2x32(ballotARB({arg})).x";
|
||||
return $"unpackUint2x32(ballotARB({arg})).{component}";
|
||||
}
|
||||
else
|
||||
{
|
||||
return $"subgroupBallot({arg}).x";
|
||||
return $"subgroupBallot({arg}).{component}";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -108,10 +108,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
|
|||
Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3);
|
||||
Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3);
|
||||
Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3);
|
||||
Add(Instruction.Shuffle, InstType.CallQuaternary, HelperFunctionNames.Shuffle);
|
||||
Add(Instruction.ShuffleDown, InstType.CallQuaternary, HelperFunctionNames.ShuffleDown);
|
||||
Add(Instruction.ShuffleUp, InstType.CallQuaternary, HelperFunctionNames.ShuffleUp);
|
||||
Add(Instruction.ShuffleXor, InstType.CallQuaternary, HelperFunctionNames.ShuffleXor);
|
||||
Add(Instruction.Shuffle, InstType.Special);
|
||||
Add(Instruction.ShuffleDown, InstType.CallBinary, "subgroupShuffleDown");
|
||||
Add(Instruction.ShuffleUp, InstType.CallBinary, "subgroupShuffleUp");
|
||||
Add(Instruction.ShuffleXor, InstType.CallBinary, "subgroupShuffleXor");
|
||||
Add(Instruction.Sine, InstType.CallUnary, "sin");
|
||||
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
|
||||
Add(Instruction.Store, InstType.Special);
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
using Ryujinx.Graphics.Shader.StructuredIr;
|
||||
using Ryujinx.Graphics.Shader.Translation;
|
||||
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper;
|
||||
|
||||
namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
|
||||
{
|
||||
static class InstGenShuffle
|
||||
{
|
||||
public static string Shuffle(CodeGenContext context, AstOperation operation)
|
||||
{
|
||||
string value = GetSoureExpr(context, operation.GetSource(0), AggregateType.FP32);
|
||||
string index = GetSoureExpr(context, operation.GetSource(1), AggregateType.U32);
|
||||
|
||||
if (context.HostCapabilities.SupportsShaderBallot)
|
||||
{
|
||||
return $"readInvocationARB({value}, {index})";
|
||||
}
|
||||
else
|
||||
{
|
||||
return $"subgroupShuffle({value}, {index})";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -231,7 +231,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
|
|||
var execution = context.Constant(context.TypeU32(), Scope.Subgroup);
|
||||
|
||||
var maskVector = context.GroupNonUniformBallot(uvec4Type, execution, context.Get(AggregateType.Bool, source));
|
||||
var mask = context.CompositeExtract(context.TypeU32(), maskVector, (SpvLiteralInteger)0);
|
||||
var mask = context.CompositeExtract(context.TypeU32(), maskVector, (SpvLiteralInteger)operation.Index);
|
||||
|
||||
return new OperationResult(AggregateType.U32, mask);
|
||||
}
|
||||
|
@ -1100,117 +1100,40 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
|
|||
|
||||
private static OperationResult GenerateShuffle(CodeGenContext context, AstOperation operation)
|
||||
{
|
||||
var x = context.GetFP32(operation.GetSource(0));
|
||||
var value = context.GetFP32(operation.GetSource(0));
|
||||
var index = context.GetU32(operation.GetSource(1));
|
||||
var mask = context.GetU32(operation.GetSource(2));
|
||||
|
||||
var const31 = context.Constant(context.TypeU32(), 31);
|
||||
var const8 = context.Constant(context.TypeU32(), 8);
|
||||
|
||||
var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31);
|
||||
var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
|
||||
var notSegMask = context.Not(context.TypeU32(), segMask);
|
||||
var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
|
||||
var indexNotSegMask = context.BitwiseAnd(context.TypeU32(), index, notSegMask);
|
||||
|
||||
var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
|
||||
|
||||
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
|
||||
var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
|
||||
var srcThreadId = context.BitwiseOr(context.TypeU32(), indexNotSegMask, minThreadId);
|
||||
var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
|
||||
var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
|
||||
var result = context.Select(context.TypeFP32(), valid, value, x);
|
||||
|
||||
var validLocal = (AstOperand)operation.GetSource(3);
|
||||
|
||||
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
|
||||
var result = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index);
|
||||
|
||||
return new OperationResult(AggregateType.FP32, result);
|
||||
}
|
||||
|
||||
private static OperationResult GenerateShuffleDown(CodeGenContext context, AstOperation operation)
|
||||
{
|
||||
var x = context.GetFP32(operation.GetSource(0));
|
||||
var value = context.GetFP32(operation.GetSource(0));
|
||||
var index = context.GetU32(operation.GetSource(1));
|
||||
var mask = context.GetU32(operation.GetSource(2));
|
||||
|
||||
var const31 = context.Constant(context.TypeU32(), 31);
|
||||
var const8 = context.Constant(context.TypeU32(), 8);
|
||||
|
||||
var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31);
|
||||
var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
|
||||
var notSegMask = context.Not(context.TypeU32(), segMask);
|
||||
var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
|
||||
|
||||
var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
|
||||
|
||||
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
|
||||
var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
|
||||
var srcThreadId = context.IAdd(context.TypeU32(), threadId, index);
|
||||
var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
|
||||
var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
|
||||
var result = context.Select(context.TypeFP32(), valid, value, x);
|
||||
|
||||
var validLocal = (AstOperand)operation.GetSource(3);
|
||||
|
||||
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
|
||||
var result = context.GroupNonUniformShuffleDown(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index);
|
||||
|
||||
return new OperationResult(AggregateType.FP32, result);
|
||||
}
|
||||
|
||||
private static OperationResult GenerateShuffleUp(CodeGenContext context, AstOperation operation)
|
||||
{
|
||||
var x = context.GetFP32(operation.GetSource(0));
|
||||
var value = context.GetFP32(operation.GetSource(0));
|
||||
var index = context.GetU32(operation.GetSource(1));
|
||||
var mask = context.GetU32(operation.GetSource(2));
|
||||
|
||||
var const31 = context.Constant(context.TypeU32(), 31);
|
||||
var const8 = context.Constant(context.TypeU32(), 8);
|
||||
|
||||
var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
|
||||
|
||||
var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
|
||||
|
||||
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
|
||||
var srcThreadId = context.ISub(context.TypeU32(), threadId, index);
|
||||
var valid = context.SGreaterThanEqual(context.TypeBool(), srcThreadId, minThreadId);
|
||||
var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
|
||||
var result = context.Select(context.TypeFP32(), valid, value, x);
|
||||
|
||||
var validLocal = (AstOperand)operation.GetSource(3);
|
||||
|
||||
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
|
||||
var result = context.GroupNonUniformShuffleUp(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index);
|
||||
|
||||
return new OperationResult(AggregateType.FP32, result);
|
||||
}
|
||||
|
||||
private static OperationResult GenerateShuffleXor(CodeGenContext context, AstOperation operation)
|
||||
{
|
||||
var x = context.GetFP32(operation.GetSource(0));
|
||||
var value = context.GetFP32(operation.GetSource(0));
|
||||
var index = context.GetU32(operation.GetSource(1));
|
||||
var mask = context.GetU32(operation.GetSource(2));
|
||||
|
||||
var const31 = context.Constant(context.TypeU32(), 31);
|
||||
var const8 = context.Constant(context.TypeU32(), 8);
|
||||
|
||||
var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31);
|
||||
var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
|
||||
var notSegMask = context.Not(context.TypeU32(), segMask);
|
||||
var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
|
||||
|
||||
var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
|
||||
|
||||
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
|
||||
var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
|
||||
var srcThreadId = context.BitwiseXor(context.TypeU32(), threadId, index);
|
||||
var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
|
||||
var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
|
||||
var result = context.Select(context.TypeFP32(), valid, value, x);
|
||||
|
||||
var validLocal = (AstOperand)operation.GetSource(3);
|
||||
|
||||
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
|
||||
var result = context.GroupNonUniformShuffleXor(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index);
|
||||
|
||||
return new OperationResult(AggregateType.FP32, result);
|
||||
}
|
||||
|
|
|
@ -28,12 +28,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
|
|||
_poolLock = new object();
|
||||
}
|
||||
|
||||
private const HelperFunctionsMask NeedsInvocationIdMask =
|
||||
HelperFunctionsMask.Shuffle |
|
||||
HelperFunctionsMask.ShuffleDown |
|
||||
HelperFunctionsMask.ShuffleUp |
|
||||
HelperFunctionsMask.ShuffleXor |
|
||||
HelperFunctionsMask.SwizzleAdd;
|
||||
private const HelperFunctionsMask NeedsInvocationIdMask = HelperFunctionsMask.SwizzleAdd;
|
||||
|
||||
public static byte[] Generate(StructuredProgramInfo info, CodeGenParameters parameters)
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue