Implement shader storage buffer operations using new Load/Store instructions (#4993)

* Implement storage buffer operations using new Load/Store instruction

* Extend GenerateMultiTargetStorageOp to also match accesses with a constant offset, and add logging and comments

* Remove now unused code

* Catch more complex cases of global memory usage

* Shader cache version bump

* Extend global access elimination to work with more shared memory cases

* Change alignment requirement from 16 bytes to 8 bytes, handle cases where we need more than 16 storage buffers

* Tweak preferencing to catch more cases

* Enable CB0 elimination even when host storage buffer alignment is > 16 (for Intel)

* Fix storage buffer bindings

* Simplify some code

* Shader cache version bump

* Fix typo

* Extend global memory elimination to handle shared memory with multiple possible offsets and local memory
This commit is contained in:
gdkchan 2023-06-03 20:12:18 -03:00 committed by GitHub
parent 81c9052847
commit 21c9ac6240
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
42 changed files with 1468 additions and 1259 deletions

View file

@ -104,14 +104,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
}
DeclareConstantBuffers(context, context.Config.Properties.ConstantBuffers.Values);
var sBufferDescriptors = context.Config.GetStorageBufferDescriptors();
if (sBufferDescriptors.Length != 0)
{
DeclareStorages(context, sBufferDescriptors);
context.AppendLine();
}
DeclareStorageBuffers(context, context.Config.Properties.StorageBuffers.Values);
var textureDescriptors = context.Config.GetTextureDescriptors();
if (textureDescriptors.Length != 0)
@ -250,11 +243,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/AtomicMinMaxS32Shared.glsl");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.AtomicMinMaxS32Storage) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/AtomicMinMaxS32Storage.glsl");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.MultiplyHighS32) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/MultiplyHighS32.glsl");
@ -290,11 +278,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/StoreSharedSmallInt.glsl");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.StoreStorageSmallInt) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/StoreStorageSmallInt.glsl");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.SwizzleAdd) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/SwizzleAdd.glsl");
@ -356,6 +339,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
}
/// <summary>
/// Declares the given constant buffers, using "uniform" as the block declaration keyword.
/// </summary>
/// <param name="context">Code generation context that receives the declarations</param>
/// <param name="buffers">Definitions of the constant buffers to declare</param>
private static void DeclareConstantBuffers(CodeGenContext context, IEnumerable<BufferDefinition> buffers)
    => DeclareBuffers(context, buffers, "uniform");
/// <summary>
/// Declares the given storage buffers, using "buffer" as the block declaration keyword.
/// </summary>
/// <param name="context">Code generation context that receives the declarations</param>
/// <param name="buffers">Definitions of the storage buffers to declare</param>
private static void DeclareStorageBuffers(CodeGenContext context, IEnumerable<BufferDefinition> buffers)
    => DeclareBuffers(context, buffers, "buffer");
private static void DeclareBuffers(CodeGenContext context, IEnumerable<BufferDefinition> buffers, string declType)
{
foreach (BufferDefinition buffer in buffers)
{
@ -365,7 +358,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
_ => "std430"
};
context.AppendLine($"layout (binding = {buffer.Binding}, {layout}) uniform _{buffer.Name}");
context.AppendLine($"layout (binding = {buffer.Binding}, {layout}) {declType} _{buffer.Name}");
context.EnterScope();
foreach (StructureField field in buffer.Type.Fields)
@ -373,9 +366,17 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
if (field.Type.HasFlag(AggregateType.Array))
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
string arraySize = field.ArrayLength.ToString(CultureInfo.InvariantCulture);
context.AppendLine($"{typeName} {field.Name}[{arraySize}];");
if (field.ArrayLength > 0)
{
string arraySize = field.ArrayLength.ToString(CultureInfo.InvariantCulture);
context.AppendLine($"{typeName} {field.Name}[{arraySize}];");
}
else
{
context.AppendLine($"{typeName} {field.Name}[];");
}
}
else
{
@ -390,22 +391,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
}
}
/// <summary>
/// Declares a single array of std430 storage blocks for the current shader stage.
/// Each block holds one unsized array of uint, and the array has one element per slot
/// up to the highest slot found in <paramref name="descriptors"/>.
/// </summary>
/// <param name="context">Code generation context that receives the declaration</param>
/// <param name="descriptors">Storage buffer descriptors; must not be empty</param>
private static void DeclareStorages(CodeGenContext context, BufferDescriptor[] descriptors)
{
    string arrayName = $"{OperandManager.GetShaderStagePrefix(context.Config.Stage)}_{DefaultNames.StorageNamePrefix}";

    // Only Vulkan gets an explicit descriptor set qualifier.
    string setQualifier = string.Empty;

    if (context.Config.Options.TargetApi == TargetApi.Vulkan)
    {
        setQualifier = ", set = 1";
    }

    context.AppendLine($"layout (binding = {context.Config.FirstStorageBufferBinding}{setQualifier}, std430) buffer {arrayName}_{DefaultNames.BlockSuffix}");
    context.EnterScope();
    context.AppendLine($"uint {DefaultNames.DataName}[];");

    // The block array must be large enough to be indexed by the highest slot in use.
    var highestSlot = descriptors.Max(descriptor => descriptor.Slot);

    context.LeaveScope($" {arrayName}[{NumberFormatter.FormatInt(highestSlot + 1)}];");
}
private static void DeclareSamplers(CodeGenContext context, TextureDescriptor[] descriptors)
{
int arraySize = 0;
@ -733,7 +718,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
code = code.Replace("\t", CodeGenContext.Tab);
code = code.Replace("$SHARED_MEM$", DefaultNames.SharedMemoryName);
code = code.Replace("$STORAGE_MEM$", OperandManager.GetShaderStagePrefix(context.Config.Stage) + "_" + DefaultNames.StorageNamePrefix);
if (context.Config.GpuAccessor.QueryHostSupportsShaderBallot())
{

View file

@ -11,12 +11,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
public const string IAttributePrefix = "in_attr";
public const string OAttributePrefix = "out_attr";
public const string StorageNamePrefix = "s";
public const string DataName = "data";
public const string BlockSuffix = "block";
public const string LocalMemoryName = "local_mem";
public const string SharedMemoryName = "shared_mem";

View file

@ -1,21 +0,0 @@
// Signed 32-bit atomic maximum on a storage buffer word.
// Replaces data[offset] of storage block "index" with max(int(current), value) using a
// compare-and-swap retry loop, and returns the value that was stored before the update.
int Helper_AtomicMaxS32(int index, int offset, int value)
{
uint oldValue, newValue;
do
{
oldValue = $STORAGE_MEM$[index].data[offset];
newValue = uint(max(int(oldValue), value));
// Retry if the word no longer holds the value read above when the swap is attempted.
} while (atomicCompSwap($STORAGE_MEM$[index].data[offset], oldValue, newValue) != oldValue);
return int(oldValue);
}
// Signed 32-bit atomic minimum on a storage buffer word.
// Replaces data[offset] of storage block "index" with min(int(current), value) using a
// compare-and-swap retry loop, and returns the value that was stored before the update.
int Helper_AtomicMinS32(int index, int offset, int value)
{
uint oldValue, newValue;
do
{
oldValue = $STORAGE_MEM$[index].data[offset];
newValue = uint(min(int(oldValue), value));
// Retry if the word no longer holds the value read above when the swap is attempted.
} while (atomicCompSwap($STORAGE_MEM$[index].data[offset], oldValue, newValue) != oldValue);
return int(oldValue);
}

View file

@ -1,23 +0,0 @@
// Stores the low 16 bits of "value" into storage block "index" without disturbing the
// other bits of the containing word. "offset" is split into a word index (offset >> 2)
// and a bit position inside that word ((offset & 3) * 8).
void Helper_StoreStorage16(int index, int offset, uint value)
{
int wordOffset = offset >> 2;
int bitOffset = (offset & 3) * 8;
uint oldValue, newValue;
do
{
oldValue = $STORAGE_MEM$[index].data[wordOffset];
newValue = bitfieldInsert(oldValue, value, bitOffset, 16);
// Retry if the word no longer holds the value read above when the swap is attempted.
} while (atomicCompSwap($STORAGE_MEM$[index].data[wordOffset], oldValue, newValue) != oldValue);
}
// Stores the low 8 bits of "value" into storage block "index" without disturbing the
// other bits of the containing word. "offset" is split into a word index (offset >> 2)
// and a bit position inside that word ((offset & 3) * 8).
void Helper_StoreStorage8(int index, int offset, uint value)
{
int wordOffset = offset >> 2;
int bitOffset = (offset & 3) * 8;
uint oldValue, newValue;
do
{
oldValue = $STORAGE_MEM$[index].data[wordOffset];
newValue = bitfieldInsert(oldValue, value, bitOffset, 8);
// Retry if the word no longer holds the value read above when the swap is attempted.
} while (atomicCompSwap($STORAGE_MEM$[index].data[wordOffset], oldValue, newValue) != oldValue);
}

View file

@ -68,33 +68,45 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
string args = string.Empty;
for (int argIndex = 0; argIndex < arity; argIndex++)
if (atomic && operation.StorageKind == StorageKind.StorageBuffer)
{
args = GenerateLoadOrStore(context, operation, isStore: false);
AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32
? AggregateType.S32
: AggregateType.U32;
for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++)
{
args += ", " + GetSoureExpr(context, operation.GetSource(argIndex), dstType);
}
}
else if (atomic && operation.StorageKind == StorageKind.SharedMemory)
{
args = LoadShared(context, operation);
// For shared memory access, the second argument is unused and should be ignored.
// It is there to make both storage and shared access have the same number of arguments.
// For storage, both inputs are consumed when the argument index is 0, so we should skip it here.
if (argIndex == 1 && (atomic || operation.StorageKind == StorageKind.SharedMemory))
{
continue;
}
if (argIndex != 0)
for (int argIndex = 2; argIndex < arity; argIndex++)
{
args += ", ";
}
if (argIndex == 0 && atomic)
AggregateType dstType = GetSrcVarType(inst, argIndex);
args += GetSoureExpr(context, operation.GetSource(argIndex), dstType);
}
}
else
{
for (int argIndex = 0; argIndex < arity; argIndex++)
{
switch (operation.StorageKind)
if (argIndex != 0)
{
case StorageKind.SharedMemory: args += LoadShared(context, operation); break;
case StorageKind.StorageBuffer: args += LoadStorage(context, operation); break;
default: throw new InvalidOperationException($"Invalid storage kind \"{operation.StorageKind}\".");
args += ", ";
}
}
else
{
AggregateType dstType = GetSrcVarType(inst, argIndex);
args += GetSoureExpr(context, operation.GetSource(argIndex), dstType);
@ -173,9 +185,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
case Instruction.LoadShared:
return LoadShared(context, operation);
case Instruction.LoadStorage:
return LoadStorage(context, operation);
case Instruction.Lod:
return Lod(context, operation);
@ -203,15 +212,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
case Instruction.StoreShared8:
return StoreShared8(context, operation);
case Instruction.StoreStorage:
return StoreStorage(context, operation);
case Instruction.StoreStorage16:
return StoreStorage16(context, operation);
case Instruction.StoreStorage8:
return StoreStorage8(context, operation);
case Instruction.TextureSample:
return TextureSample(context, operation);

View file

@ -85,7 +85,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
Add(Instruction.Load, InstType.Special);
Add(Instruction.LoadLocal, InstType.Special);
Add(Instruction.LoadShared, InstType.Special);
Add(Instruction.LoadStorage, InstType.Special);
Add(Instruction.Lod, InstType.Special);
Add(Instruction.LogarithmB2, InstType.CallUnary, "log2");
Add(Instruction.LogicalAnd, InstType.OpBinaryCom, "&&", 9);
@ -123,9 +122,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
Add(Instruction.StoreShared, InstType.Special);
Add(Instruction.StoreShared16, InstType.Special);
Add(Instruction.StoreShared8, InstType.Special);
Add(Instruction.StoreStorage, InstType.Special);
Add(Instruction.StoreStorage16, InstType.Special);
Add(Instruction.StoreStorage8, InstType.Special);
Add(Instruction.Subtract, InstType.OpBinary, "-", 2);
Add(Instruction.SwizzleAdd, InstType.CallTernary, HelperFunctionNames.SwizzleAdd);
Add(Instruction.TextureSample, InstType.Special);

View file

@ -210,17 +210,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
return $"{arrayName}[{offsetExpr}]";
}
/// <summary>
/// Builds the expression that reads one element of a storage buffer.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Operation whose source 0 is the buffer slot and source 1 is the element offset</param>
/// <returns>Accessor expression for the addressed storage buffer element</returns>
public static string LoadStorage(CodeGenContext context, AstOperation operation)
{
    IAstNode slotNode = operation.GetSource(0);
    IAstNode offsetNode = operation.GetSource(1);

    AggregateType slotType = GetSrcVarType(operation.Inst, 0);
    string slot = GetSoureExpr(context, slotNode, slotType);

    AggregateType offsetType = GetSrcVarType(operation.Inst, 1);
    string offset = GetSoureExpr(context, offsetNode, offsetType);

    return GetStorageBufferAccessor(slot, offset, context.Config.Stage);
}
public static string Lod(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
@ -326,60 +315,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
return $"{HelperFunctionNames.StoreShared8}({offsetExpr}, {src})";
}
/// <summary>
/// Builds the assignment statement that writes a 32-bit value to a storage buffer element.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Operation whose sources are, in order, the buffer slot, the element offset and the value</param>
/// <returns>Assignment of the value, reinterpreted as U32, to the addressed element</returns>
public static string StoreStorage(CodeGenContext context, AstOperation operation)
{
    IAstNode slotNode = operation.GetSource(0);
    IAstNode offsetNode = operation.GetSource(1);
    IAstNode valueNode = operation.GetSource(2);

    string slot = GetSoureExpr(context, slotNode, GetSrcVarType(operation.Inst, 0));
    string offset = GetSoureExpr(context, offsetNode, GetSrcVarType(operation.Inst, 1));

    // The value is bit-cast from whatever type the node produces to U32.
    AggregateType valueType = OperandManager.GetNodeDestType(context, valueNode);
    string value = TypeConversion.ReinterpretCast(context, valueNode, valueType, AggregateType.U32);

    string target = GetStorageBufferAccessor(slot, offset, context.Config.Stage);

    return target + " = " + value;
}
/// <summary>
/// Builds the call to the <c>StoreStorage16</c> helper function, which performs the store.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Operation whose sources are, in order, the buffer slot, the offset and the value</param>
/// <returns>Helper call expression taking the slot, the offset and the value reinterpreted as U32</returns>
public static string StoreStorage16(CodeGenContext context, AstOperation operation)
{
    IAstNode src1 = operation.GetSource(0);
    IAstNode src2 = operation.GetSource(1);
    IAstNode src3 = operation.GetSource(2);

    string indexExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 0));
    string offsetExpr = GetSoureExpr(context, src2, GetSrcVarType(operation.Inst, 1));

    AggregateType srcType = OperandManager.GetNodeDestType(context, src3);
    string src = TypeConversion.ReinterpretCast(context, src3, srcType, AggregateType.U32);

    // The helper receives the slot and offset separately, so no buffer accessor expression
    // is needed here (the previously computed, never used, accessor string was removed).
    return $"{HelperFunctionNames.StoreStorage16}({indexExpr}, {offsetExpr}, {src})";
}
/// <summary>
/// Builds the call to the <c>StoreStorage8</c> helper function, which performs the store.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Operation whose sources are, in order, the buffer slot, the offset and the value</param>
/// <returns>Helper call expression taking the slot, the offset and the value reinterpreted as U32</returns>
public static string StoreStorage8(CodeGenContext context, AstOperation operation)
{
    IAstNode src1 = operation.GetSource(0);
    IAstNode src2 = operation.GetSource(1);
    IAstNode src3 = operation.GetSource(2);

    string indexExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 0));
    string offsetExpr = GetSoureExpr(context, src2, GetSrcVarType(operation.Inst, 1));

    AggregateType srcType = OperandManager.GetNodeDestType(context, src3);
    string src = TypeConversion.ReinterpretCast(context, src3, srcType, AggregateType.U32);

    // The helper receives the slot and offset separately, so no buffer accessor expression
    // is needed here (the previously computed, never used, accessor string was removed).
    return $"{HelperFunctionNames.StoreStorage8}({indexExpr}, {offsetExpr}, {src})";
}
public static string TextureSample(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
@ -701,25 +636,34 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
}
}
private static string GenerateLoadOrStore(CodeGenContext context, AstOperation operation, bool isStore)
public static string GenerateLoadOrStore(CodeGenContext context, AstOperation operation, bool isStore)
{
StorageKind storageKind = operation.StorageKind;
string varName;
AggregateType varType;
int srcIndex = 0;
int inputsCount = isStore ? operation.SourcesCount - 1 : operation.SourcesCount;
bool isStoreOrAtomic = operation.Inst == Instruction.Store || operation.Inst.IsAtomic();
int inputsCount = isStoreOrAtomic ? operation.SourcesCount - 1 : operation.SourcesCount;
if (operation.Inst == Instruction.AtomicCompareAndSwap)
{
inputsCount--;
}
switch (storageKind)
{
case StorageKind.ConstantBuffer:
case StorageKind.StorageBuffer:
if (!(operation.GetSource(srcIndex++) is AstOperand bindingIndex) || bindingIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand.");
}
int binding = bindingIndex.Value;
BufferDefinition buffer = context.Config.Properties.ConstantBuffers[binding];
BufferDefinition buffer = storageKind == StorageKind.ConstantBuffer
? context.Config.Properties.ConstantBuffers[binding]
: context.Config.Properties.StorageBuffers[binding];
if (!(operation.GetSource(srcIndex++) is AstOperand fieldIndex) || fieldIndex.Type != OperandType.Constant)
{
@ -825,15 +769,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
return varName;
}
/// <summary>
/// Formats the expression that indexes the per-stage storage buffer array by slot,
/// then indexes the data member of that block by offset.
/// </summary>
/// <param name="slotExpr">Expression selecting the storage buffer slot</param>
/// <param name="offsetExpr">Expression selecting the element inside the buffer</param>
/// <param name="stage">Shader stage, used to pick the name prefix of the array</param>
/// <returns>The accessor expression</returns>
private static string GetStorageBufferAccessor(string slotExpr, string offsetExpr, ShaderStage stage)
{
    string stagePrefix = OperandManager.GetShaderStagePrefix(stage);

    return $"{stagePrefix}_{DefaultNames.StorageNamePrefix}[{slotExpr}].{DefaultNames.DataName}[{offsetExpr}]";
}
private static string GetMask(int index)
{
return $".{"rgba".AsSpan(index, 1)}";

View file

@ -118,6 +118,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
switch (operation.StorageKind)
{
case StorageKind.ConstantBuffer:
case StorageKind.StorageBuffer:
if (!(operation.GetSource(0) is AstOperand bindingIndex) || bindingIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
@ -128,7 +129,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
throw new InvalidOperationException($"Second input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
BufferDefinition buffer = context.Config.Properties.ConstantBuffers[bindingIndex.Value];
BufferDefinition buffer = operation.StorageKind == StorageKind.ConstantBuffer
? context.Config.Properties.ConstantBuffers[bindingIndex.Value]
: context.Config.Properties.StorageBuffers[bindingIndex.Value];
StructureField field = buffer.Type.Fields[fieldIndex.Value];
return field.Type & AggregateType.ElementTypeMask;

View file

@ -24,7 +24,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
public int InputVertices { get; }
public Dictionary<int, Instruction> ConstantBuffers { get; } = new Dictionary<int, Instruction>();
public Instruction StorageBuffersArray { get; set; }
public Dictionary<int, Instruction> StorageBuffers { get; } = new Dictionary<int, Instruction>();
public Instruction LocalMemory { get; set; }
public Instruction SharedMemory { get; set; }
public Dictionary<TextureMeta, SamplerType> SamplersTypes { get; } = new Dictionary<TextureMeta, SamplerType>();
@ -308,7 +308,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
{
if ((type & AggregateType.Array) != 0)
{
return TypeArray(GetType(type & ~AggregateType.Array), Constant(TypeU32(), length));
if (length > 0)
{
return TypeArray(GetType(type & ~AggregateType.Array), Constant(TypeU32(), length));
}
else
{
return TypeRuntimeArray(GetType(type & ~AggregateType.Array));
}
}
else if ((type & AggregateType.ElementCountMask) != 0)
{

View file

@ -5,6 +5,7 @@ using Ryujinx.Graphics.Shader.Translation;
using Spv.Generator;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Numerics;
using static Spv.Specification;
@ -99,7 +100,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
}
DeclareConstantBuffers(context, context.Config.Properties.ConstantBuffers.Values);
DeclareStorageBuffers(context, context.Config.GetStorageBufferDescriptors());
DeclareStorageBuffers(context, context.Config.Properties.StorageBuffers.Values);
DeclareSamplers(context, context.Config.GetTextureDescriptors());
DeclareImages(context, context.Config.GetImageDescriptors());
DeclareInputsAndOutputs(context, info);
@ -127,6 +128,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
}
/// <summary>
/// Declares the given constant buffers (non-"buffer" variant of <c>DeclareBuffers</c>).
/// </summary>
/// <param name="context">Code generation context that receives the declarations</param>
/// <param name="buffers">Definitions of the constant buffers to declare</param>
private static void DeclareConstantBuffers(CodeGenContext context, IEnumerable<BufferDefinition> buffers)
    => DeclareBuffers(context, buffers, isBuffer: false);
/// <summary>
/// Declares the given storage buffers ("buffer" variant of <c>DeclareBuffers</c>).
/// </summary>
/// <param name="context">Code generation context that receives the declarations</param>
/// <param name="buffers">Definitions of the storage buffers to declare</param>
private static void DeclareStorageBuffers(CodeGenContext context, IEnumerable<BufferDefinition> buffers)
    => DeclareBuffers(context, buffers, isBuffer: true);
private static void DeclareBuffers(CodeGenContext context, IEnumerable<BufferDefinition> buffers, bool isBuffer)
{
HashSet<SpvInstruction> decoratedTypes = new HashSet<SpvInstruction>();
@ -155,6 +166,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
context.Decorate(structFieldTypes[fieldIndex], Decoration.ArrayStride, (LiteralInteger)fieldSize);
}
// Zero lengths are assumed to be a "runtime array" (which does not have a explicit length
// specified on the shader, and instead assumes the bound buffer length).
// It is only valid as the last struct element.
Debug.Assert(field.ArrayLength > 0 || fieldIndex == buffer.Type.Fields.Length - 1);
offset += fieldSize * field.ArrayLength;
}
else
@ -163,56 +180,37 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
}
}
var ubStructType = context.TypeStruct(false, structFieldTypes);
var structType = context.TypeStruct(false, structFieldTypes);
if (decoratedTypes.Add(ubStructType))
if (decoratedTypes.Add(structType))
{
context.Decorate(ubStructType, Decoration.Block);
context.Decorate(structType, isBuffer ? Decoration.BufferBlock : Decoration.Block);
for (int fieldIndex = 0; fieldIndex < structFieldOffsets.Length; fieldIndex++)
{
context.MemberDecorate(ubStructType, fieldIndex, Decoration.Offset, (LiteralInteger)structFieldOffsets[fieldIndex]);
context.MemberDecorate(structType, fieldIndex, Decoration.Offset, (LiteralInteger)structFieldOffsets[fieldIndex]);
}
}
var ubPointerType = context.TypePointer(StorageClass.Uniform, ubStructType);
var ubVariable = context.Variable(ubPointerType, StorageClass.Uniform);
var pointerType = context.TypePointer(StorageClass.Uniform, structType);
var variable = context.Variable(pointerType, StorageClass.Uniform);
context.Name(ubVariable, buffer.Name);
context.Decorate(ubVariable, Decoration.DescriptorSet, (LiteralInteger)buffer.Set);
context.Decorate(ubVariable, Decoration.Binding, (LiteralInteger)buffer.Binding);
context.AddGlobalVariable(ubVariable);
context.ConstantBuffers.Add(buffer.Binding, ubVariable);
context.Name(variable, buffer.Name);
context.Decorate(variable, Decoration.DescriptorSet, (LiteralInteger)buffer.Set);
context.Decorate(variable, Decoration.Binding, (LiteralInteger)buffer.Binding);
context.AddGlobalVariable(variable);
if (isBuffer)
{
context.StorageBuffers.Add(buffer.Binding, variable);
}
else
{
context.ConstantBuffers.Add(buffer.Binding, variable);
}
}
}
/// <summary>
/// Declares all storage buffers as a single array of buffer blocks, and records the
/// resulting variable in <c>context.StorageBuffersArray</c>.
/// Does nothing when there are no descriptors.
/// </summary>
/// <param name="context">Code generation context that receives the declarations</param>
/// <param name="descriptors">Storage buffer descriptors used by the shader</param>
private static void DeclareStorageBuffers(CodeGenContext context, BufferDescriptor[] descriptors)
{
if (descriptors.Length == 0)
{
return;
}
// Descriptor set 1 on Vulkan, 0 otherwise.
int setIndex = context.Config.Options.TargetApi == TargetApi.Vulkan ? 1 : 0;
// The array needs one element per slot, up to the highest slot in use.
int count = descriptors.Max(x => x.Slot) + 1;
// Each block is a struct with a single member: a runtime array of uint with a 4 byte stride.
var sbArrayType = context.TypeRuntimeArray(context.TypeU32());
context.Decorate(sbArrayType, Decoration.ArrayStride, (LiteralInteger)4);
var sbStructType = context.TypeStruct(true, sbArrayType);
context.Decorate(sbStructType, Decoration.BufferBlock);
context.MemberDecorate(sbStructType, 0, Decoration.Offset, (LiteralInteger)0);
var sbStructArrayType = context.TypeArray(sbStructType, context.Constant(context.TypeU32(), count));
var sbPointerType = context.TypePointer(StorageClass.Uniform, sbStructArrayType);
var sbVariable = context.Variable(sbPointerType, StorageClass.Uniform);
context.Name(sbVariable, $"{GetStagePrefix(context.Config.Stage)}_s");
context.Decorate(sbVariable, Decoration.DescriptorSet, (LiteralInteger)setIndex);
context.Decorate(sbVariable, Decoration.Binding, (LiteralInteger)context.Config.FirstStorageBufferBinding);
context.AddGlobalVariable(sbVariable);
context.StorageBuffersArray = sbVariable;
}
private static void DeclareSamplers(CodeGenContext context, TextureDescriptor[] descriptors)
{
foreach (var descriptor in descriptors)

View file

@ -99,7 +99,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
Add(Instruction.Load, GenerateLoad);
Add(Instruction.LoadLocal, GenerateLoadLocal);
Add(Instruction.LoadShared, GenerateLoadShared);
Add(Instruction.LoadStorage, GenerateLoadStorage);
Add(Instruction.Lod, GenerateLod);
Add(Instruction.LogarithmB2, GenerateLogarithmB2);
Add(Instruction.LogicalAnd, GenerateLogicalAnd);
@ -137,9 +136,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
Add(Instruction.StoreShared, GenerateStoreShared);
Add(Instruction.StoreShared16, GenerateStoreShared16);
Add(Instruction.StoreShared8, GenerateStoreShared8);
Add(Instruction.StoreStorage, GenerateStoreStorage);
Add(Instruction.StoreStorage16, GenerateStoreStorage16);
Add(Instruction.StoreStorage8, GenerateStoreStorage8);
Add(Instruction.Subtract, GenerateSubtract);
Add(Instruction.SwizzleAdd, GenerateSwizzleAdd);
Add(Instruction.TextureSample, GenerateTextureSample);
@ -889,14 +885,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
return new OperationResult(AggregateType.U32, value);
}
/// <summary>
/// Generates a U32 load from the storage buffer element addressed by the operation.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Storage load operation</param>
/// <returns>The loaded value, typed as U32</returns>
private static OperationResult GenerateLoadStorage(CodeGenContext context, AstOperation operation)
{
    // The pointer is computed first so that the access chain is emitted before the load.
    SpvInstruction source = GetStorageElemPointer(context, operation);

    return new OperationResult(AggregateType.U32, context.Load(context.TypeU32(), source));
}
private static OperationResult GenerateLod(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
@ -1307,28 +1295,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
return OperationResult.Invalid;
}
/// <summary>
/// Generates a U32 store of source 2 to the storage buffer element addressed by the operation.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Storage store operation</param>
/// <returns><c>OperationResult.Invalid</c>, since a store produces no value</returns>
private static OperationResult GenerateStoreStorage(CodeGenContext context, AstOperation operation)
{
    SpvInstruction destination = GetStorageElemPointer(context, operation);
    var storedValue = context.Get(AggregateType.U32, operation.GetSource(2));

    context.Store(destination, storedValue);

    return OperationResult.Invalid;
}
/// <summary>
/// Generates a 16-bit store to a storage buffer.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Storage store operation</param>
/// <returns><c>OperationResult.Invalid</c>, since a store produces no value</returns>
private static OperationResult GenerateStoreStorage16(CodeGenContext context, AstOperation operation)
{
    const int BitSize = 16;

    GenerateStoreStorageSmallInt(context, operation, BitSize);

    return OperationResult.Invalid;
}
/// <summary>
/// Generates an 8-bit store to a storage buffer.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Storage store operation</param>
/// <returns><c>OperationResult.Invalid</c>, since a store produces no value</returns>
private static OperationResult GenerateStoreStorage8(CodeGenContext context, AstOperation operation)
{
    const int BitSize = 8;

    GenerateStoreStorageSmallInt(context, operation, BitSize);

    return OperationResult.Invalid;
}
private static OperationResult GenerateSubtract(CodeGenContext context, AstOperation operation)
{
return GenerateBinary(context, operation, context.Delegates.FSub, context.Delegates.ISub);
@ -1849,13 +1815,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
AstOperation operation,
Func<SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction, SpvInstruction> emitU)
{
var value = context.GetU32(operation.GetSource(2));
var value = context.GetU32(operation.GetSource(operation.SourcesCount - 1));
SpvInstruction elemPointer;
if (operation.StorageKind == StorageKind.StorageBuffer)
{
elemPointer = GetStorageElemPointer(context, operation);
elemPointer = GetStoragePointer(context, operation, out _);
}
else if (operation.StorageKind == StorageKind.SharedMemory)
{
@ -1875,14 +1841,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
private static OperationResult GenerateAtomicMemoryCas(CodeGenContext context, AstOperation operation)
{
var value0 = context.GetU32(operation.GetSource(2));
var value1 = context.GetU32(operation.GetSource(3));
var value0 = context.GetU32(operation.GetSource(operation.SourcesCount - 2));
var value1 = context.GetU32(operation.GetSource(operation.SourcesCount - 1));
SpvInstruction elemPointer;
if (operation.StorageKind == StorageKind.StorageBuffer)
{
elemPointer = GetStorageElemPointer(context, operation);
elemPointer = GetStoragePointer(context, operation, out _);
}
else if (operation.StorageKind == StorageKind.SharedMemory)
{
@ -1901,17 +1867,33 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
}
/// <summary>
/// Generates a load from, or a store to, the location addressed by the operation.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Load or store operation</param>
/// <param name="isStore">True to store the last source of the operation, false to load</param>
/// <returns>The loaded value, or <c>OperationResult.Invalid</c> for a store</returns>
private static OperationResult GenerateLoadOrStore(CodeGenContext context, AstOperation operation, bool isStore)
{
    SpvInstruction target = GetStoragePointer(context, operation, out AggregateType elemType);

    if (!isStore)
    {
        return new OperationResult(elemType, context.Load(context.GetType(elemType), target));
    }

    // For stores, the value to be written is always the last source.
    context.Store(target, context.Get(elemType, operation.GetSource(operation.SourcesCount - 1)));

    return OperationResult.Invalid;
}
private static SpvInstruction GetStoragePointer(CodeGenContext context, AstOperation operation, out AggregateType varType)
{
StorageKind storageKind = operation.StorageKind;
StorageClass storageClass;
SpvInstruction baseObj;
AggregateType varType;
int srcIndex = 0;
switch (storageKind)
{
case StorageKind.ConstantBuffer:
case StorageKind.StorageBuffer:
if (!(operation.GetSource(srcIndex++) is AstOperand bindingIndex) || bindingIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand.");
@ -1922,12 +1904,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
throw new InvalidOperationException($"Second input of {operation.Inst} with {storageKind} storage must be a constant operand.");
}
BufferDefinition buffer = context.Config.Properties.ConstantBuffers[bindingIndex.Value];
BufferDefinition buffer = storageKind == StorageKind.ConstantBuffer
? context.Config.Properties.ConstantBuffers[bindingIndex.Value]
: context.Config.Properties.StorageBuffers[bindingIndex.Value];
StructureField field = buffer.Type.Fields[fieldIndex.Value];
storageClass = StorageClass.Uniform;
varType = field.Type & AggregateType.ElementTypeMask;
baseObj = context.ConstantBuffers[bindingIndex.Value];
baseObj = storageKind == StorageKind.ConstantBuffer
? context.ConstantBuffers[bindingIndex.Value]
: context.StorageBuffers[bindingIndex.Value];
break;
case StorageKind.Input:
@ -1993,7 +1979,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
throw new InvalidOperationException($"Invalid storage kind {storageKind}.");
}
int inputsCount = (isStore ? operation.SourcesCount - 1 : operation.SourcesCount) - srcIndex;
bool isStoreOrAtomic = operation.Inst == Instruction.Store || operation.Inst.IsAtomic();
int inputsCount = (isStoreOrAtomic ? operation.SourcesCount - 1 : operation.SourcesCount) - srcIndex;
if (operation.Inst == Instruction.AtomicCompareAndSwap)
{
inputsCount--;
}
SpvInstruction e0, e1, e2;
SpvInstruction pointer;
@ -2030,16 +2023,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
break;
}
if (isStore)
{
context.Store(pointer, context.Get(varType, operation.GetSource(srcIndex)));
return OperationResult.Invalid;
}
else
{
var result = context.Load(context.GetType(varType), pointer);
return new OperationResult(varType, result);
}
return pointer;
}
private static SpvInstruction GetScalarInput(CodeGenContext context, IoVariable ioVariable)
@ -2068,25 +2052,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
GenerateStoreSmallInt(context, elemPointer, bitOffset, value, bitSize);
}
/// <summary>
/// Generates a store of a value narrower than 32 bits into a storage buffer.
/// The containing 32-bit word is located here; the actual store is delegated to <c>GenerateStoreSmallInt</c>.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Operation whose sources are, in order, the buffer slot, the offset and the value</param>
/// <param name="bitSize">Width in bits of the value to store, forwarded to <c>GenerateStoreSmallInt</c></param>
private static void GenerateStoreStorageSmallInt(CodeGenContext context, AstOperation operation, int bitSize)
{
var i0 = context.Get(AggregateType.S32, operation.GetSource(0));
var offset = context.Get(AggregateType.U32, operation.GetSource(1));
var value = context.Get(AggregateType.U32, operation.GetSource(2));
// wordOffset = offset >> 2, the index of the 32-bit word containing the target.
var wordOffset = context.ShiftRightLogical(context.TypeU32(), offset, context.Constant(context.TypeU32(), 2));
// bitOffset = (offset & 3) << 3, the bit position of the target inside that word.
var bitOffset = context.BitwiseAnd(context.TypeU32(), offset, context.Constant(context.TypeU32(), 3));
bitOffset = context.ShiftLeftLogical(context.TypeU32(), bitOffset, context.Constant(context.TypeU32(), 3));
var sbVariable = context.StorageBuffersArray;
// Access chain indices: buffer slot, then constant index 0, then the word index.
var i1 = context.Constant(context.TypeS32(), 0);
var elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeU32()), sbVariable, i0, i1, wordOffset);
GenerateStoreSmallInt(context, elemPointer, bitOffset, value, bitSize);
}
private static void GenerateStoreSmallInt(
CodeGenContext context,
SpvInstruction elemPointer,
@ -2173,16 +2138,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
}
}
/// <summary>
/// Builds a pointer to a U32 element of the storage buffers array.
/// </summary>
/// <param name="context">Code generation context</param>
/// <param name="operation">Operation whose source 0 is the buffer slot and source 1 is the element index</param>
/// <returns>Access chain pointing at the addressed element</returns>
private static SpvInstruction GetStorageElemPointer(CodeGenContext context, AstOperation operation)
{
    var buffersArray = context.StorageBuffersArray;

    // Access chain indices: buffer slot, then constant index 0, then the element index.
    var slotIndex = context.Get(AggregateType.S32, operation.GetSource(0));
    var memberIndex = context.Constant(context.TypeS32(), 0);
    var elemIndex = context.Get(AggregateType.S32, operation.GetSource(1));

    var elemPointerType = context.TypePointer(StorageClass.Uniform, context.TypeU32());

    return context.AccessChain(elemPointerType, buffersArray, slotIndex, memberIndex, elemIndex);
}
private static OperationResult GenerateUnary(
CodeGenContext context,
AstOperation operation,