Skip to content

Commit

Permalink
Shader decompiler: More control flow
Browse files Browse the repository at this point in the history
  • Loading branch information
wheremyfoodat committed Aug 19, 2024
1 parent 7e04ab7 commit e481ce8
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 20 deletions.
5 changes: 4 additions & 1 deletion include/PICA/shader_decompiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ namespace PICA::ShaderGen {
bool operator<(const Function& other) const { return AddressRange(start, end) < AddressRange(other.start, other.end); }

std::string getIdentifier() const { return fmt::format("fn_{}_{}", start, end); }
std::string getForwardDecl() const { return fmt::format("void fn_{}_{}();\n", start, end); }
// To handle weird control flow, we have to return from each function a bool that indicates whether or not the shader reached an end
// instruction and should thus terminate. This is necessary for games like Rayman and Gravity Falls, which have "END" instructions called
// from within functions deep in the callstack
std::string getForwardDecl() const { return fmt::format("bool fn_{}_{}();\n", start, end); }
std::string getCallStatement() const { return fmt::format("fn_{}_{}()", start, end); }
};

Expand Down
95 changes: 76 additions & 19 deletions src/core/PICA/shader_decompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,33 @@ ExitMode ControlFlow::analyzeFunction(const PICAShader& shader, u32 start, u32 e
}

// Exit mode of the remainder of this function, after we return from the callee
ExitMode postCallExitMode = analyzeFunction(shader, pc + 1, end, labels);
ExitMode exitMode = exitSeries(postCallExitMode, calledFunction->exitMode);
const ExitMode postCallExitMode = analyzeFunction(shader, pc + 1, end, labels);
const ExitMode exitMode = exitSeries(postCallExitMode, calledFunction->exitMode);

it->second = exitMode;
return exitMode;
}
case ShaderOpcodes::CALLC: Helpers::panic("Unimplemented control flow operation (CALLC)"); break;
case ShaderOpcodes::CALLU: Helpers::panic("Unimplemented control flow operation (CALLU)"); break;

case ShaderOpcodes::CALLC:
case ShaderOpcodes::CALLU: {
const u32 num = instruction & 0xff;
const u32 dest = getBits<10, 12>(instruction);
const Function* calledFunction = addFunction(shader, dest, dest + num);

// Check if analysis of the branch taken func failed and return unknown if it did
if (analysisFailed) {
it->second = ExitMode::Unknown;
return it->second;
}

// Exit mode of the remainder of this function, after we return from the callee
const ExitMode postCallExitMode = analyzeFunction(shader, pc + 1, end, labels);
const ExitMode exitMode = exitSeries(exitParallel(calledFunction->exitMode, ExitMode::AlwaysReturn), postCallExitMode);

it->second = exitMode;
return exitMode;
}

case ShaderOpcodes::LOOP: {
u32 dest = getBits<10, 12>(instruction);
const Function* loopFunction = addFunction(shader, pc + 1, dest + 1);
Expand All @@ -159,13 +178,13 @@ ExitMode ControlFlow::analyzeFunction(const PICAShader& shader, u32 start, u32 e
return it->second;
}

ExitMode afterLoop = analyzeFunction(shader, dest + 1, end, labels);
ExitMode exitMode = exitSeries(afterLoop, loopFunction->exitMode);
const ExitMode afterLoop = analyzeFunction(shader, dest + 1, end, labels);
const ExitMode exitMode = exitSeries(afterLoop, loopFunction->exitMode);
it->second = exitMode;
return it->second;
}
case ShaderOpcodes::END: it->second = ExitMode::AlwaysEnd; return it->second;

case ShaderOpcodes::END: it->second = ExitMode::AlwaysEnd; return it->second;
default: break;
}
}
Expand Down Expand Up @@ -251,23 +270,28 @@ std::string ShaderDecompiler::decompile() {
decompiledShader += func.getForwardDecl();
}

decompiledShader += "void pica_shader_main() {\n";
decompiledShader += "bool pica_shader_main() {\n";
AddressRange mainFunctionRange(entrypoint, PICAShader::maxInstructionCount);
callFunction(*findFunction(mainFunctionRange));
decompiledShader += "}\n";
decompiledShader += "return true;\n}\n";

for (const Function& func : controlFlow.functions) {
if (func.outLabels.empty()) {
decompiledShader += fmt::format("void {}() {{\n", func.getIdentifier());
compileRange(AddressRange(func.start, func.end));
decompiledShader += fmt::format("bool {}() {{\n", func.getIdentifier());

auto [pc, finished] = compileRange(AddressRange(func.start, func.end));
if (!finished) {
decompiledShader += "return false;";
}

decompiledShader += "}\n";
} else {
auto labels = func.outLabels;
labels.insert(func.start);

// If a function has jumps and "labels", this needs to be emulated using a switch-case, with the variable being switched on being the
// current PC
decompiledShader += fmt::format("void {}() {{\n", func.getIdentifier());
decompiledShader += fmt::format("bool {}() {{\n", func.getIdentifier());
decompiledShader += fmt::format("uint pc = {}u;\n", func.start);
decompiledShader += "while(true){\nswitch(pc){\n";

Expand All @@ -287,12 +311,12 @@ std::string ShaderDecompiler::decompile() {
decompiledShader += "}\n";
}

decompiledShader += "default: return;\n";
decompiledShader += "default: return false;\n";
// Exit the switch and loop
decompiledShader += "} }\n";

// Exit the function
decompiledShader += "return;\n";
decompiledShader += "return false;\n";
decompiledShader += "}\n";
}
}
Expand Down Expand Up @@ -613,12 +637,35 @@ void ShaderDecompiler::compileInstruction(u32& pc, bool& finished) {
return;
}

case ShaderOpcodes::CALL: {
case ShaderOpcodes::CALL:
case ShaderOpcodes::CALLC:
case ShaderOpcodes::CALLU: {
const u32 num = instruction & 0xff;
const u32 dest = getBits<10, 12>(instruction);
const Function* calledFunc = findFunction(AddressRange(dest, dest + num));

// Handle conditions for CALLC/CALLU
if (opcode == ShaderOpcodes::CALLC) {
const u32 condOp = getBits<22, 2>(instruction);
const uint refY = getBit<24>(instruction);
const uint refX = getBit<25>(instruction);
const char* condition = getCondition(condOp, refX, refY);

decompiledShader += fmt::format("if ({}) {{", condition);
} else if (opcode == ShaderOpcodes::CALLU) {
const u32 bit = getBits<22, 4>(instruction); // Bit of the bool uniform to check
const u32 mask = 1u << bit;

decompiledShader += fmt::format("if ((uniform_bool & {}u) != 0u) {{", mask);
}

callFunction(*calledFunc);

// Close brackets for CALLC/CALLU
if (opcode != ShaderOpcodes::CALL) {
decompiledShader += "}";
}

if (opcode == ShaderOpcodes::CALL && calledFunc->exitMode == ExitMode::AlwaysEnd) {
finished = true;
return;
Expand Down Expand Up @@ -651,7 +698,7 @@ void ShaderDecompiler::compileInstruction(u32& pc, bool& finished) {
}

case ShaderOpcodes::END:
decompiledShader += "return;\n";
decompiledShader += "return true;\n";
finished = true;
return;

Expand Down Expand Up @@ -686,13 +733,23 @@ bool ShaderDecompiler::usesCommonEncoding(u32 instruction) const {
case ShaderOpcodes::SLT:
case ShaderOpcodes::SLTI:
case ShaderOpcodes::SGE:
case ShaderOpcodes::SGEI: return true;
case ShaderOpcodes::SGEI:
case ShaderOpcodes::LITP: return true;

default: return false;
}
}

void ShaderDecompiler::callFunction(const Function& function) { decompiledShader += function.getCallStatement() + ";\n"; }
void ShaderDecompiler::callFunction(const Function& function) {
switch (function.exitMode) {
// This function always ends, so call it and return true to signal that we're gonna be ending the shader
case ExitMode::AlwaysEnd: decompiledShader += function.getCallStatement() + ";\nreturn true;\n"; break;
// This function will potentially end. Call it, see if it returns that it ended, and return that we're ending if it did
case ExitMode::Conditional: decompiledShader += fmt::format("if ({}) {{ return true; }}\n", function.getCallStatement()); break;
// This function will not end. Just call it like a normal function.
default: decompiledShader += function.getCallStatement() + ";\n"; break;
}
}

std::string ShaderGen::decompileShader(PICAShader& shader, EmulatorConfig& config, u32 entrypoint, API api, Language language) {
ShaderDecompiler decompiler(shader, config, entrypoint, api, language);
Expand Down Expand Up @@ -726,7 +783,7 @@ const char* ShaderDecompiler::getCondition(u32 cond, u32 refX, u32 refY) {
"cmp_reg.x",
"cmp_reg.y",
};
u32 key = (cond & 0b11) | (refX << 2) | (refY << 3);
const u32 key = (cond & 0b11) | (refX << 2) | (refY << 3);

return conditions[key];
}

0 comments on commit e481ce8

Please sign in to comment.