diff --git a/cpp/.gitignore b/cpp/.gitignore new file mode 100644 index 0000000..dd81521 --- /dev/null +++ b/cpp/.gitignore @@ -0,0 +1,2 @@ +build +cmake-* \ No newline at end of file diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt new file mode 100644 index 0000000..6162677 --- /dev/null +++ b/cpp/CMakeLists.txt @@ -0,0 +1,132 @@ +cmake_minimum_required(VERSION 3.20) +project(rspl VERSION 1.0.0 LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +# Dependencies via FetchContent +include(FetchContent) + +FetchContent_Declare( + nlohmann_json + GIT_REPOSITORY https://github.com/nlohmann/json.git + GIT_TAG v3.11.3 +) +FetchContent_MakeAvailable(nlohmann_json) + +FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2.git + GIT_TAG v3.7.1 +) +FetchContent_MakeAvailable(Catch2) + +# --- Library sources ------------------------------------------------------ + +set(RSPL_SOURCES + src/asm.cpp + src/asm.h + src/asm_normalize.cpp + src/asm_normalize.h + src/asm_writer.cpp + src/asm_writer.h + src/ast.cpp + src/ast.h + src/ast2asm.cpp + src/ast2asm.h + src/astCalcNormalizer.cpp + src/astCalcNormalizer.h + src/builtins.cpp + src/builtins.h + src/operations/branch.cpp + src/operations/branch.h + src/operations/scalar.cpp + src/operations/scalar.h + src/operations/user_function.cpp + src/operations/user_function.h + src/operations/vector.cpp + src/operations/vector.h + src/optimizer/asm_optimizer.cpp + src/optimizer/asm_optimizer.h + src/optimizer/asm_scan_deps.cpp + src/optimizer/asm_scan_deps.h + src/optimizer/eval_cost.cpp + src/optimizer/eval_cost.h + src/pipeline.cpp + src/pipeline.h + src/preproc.cpp + src/preproc.h + src/registers.cpp + src/registers.h + src/state.cpp + src/state.h + src/swizzle.cpp + src/swizzle.h + src/types.cpp + src/types.h +) + +add_library(rspl_core STATIC ${RSPL_SOURCES}) +target_include_directories(rspl_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/src) +target_link_libraries(rspl_core PUBLIC nlohmann_json::nlohmann_json) + +# --- CLI executable ------------------------------------------------------- + +add_executable(rspl src/main.cpp) +target_link_libraries(rspl PRIVATE rspl_core) + +# --- Tests ---------------------------------------------------------------- + +set(TEST_SOURCES + tests/test_annotations.cpp + tests/test_ast.cpp + tests/test_branchConst.cpp + tests/test_builtins.cpp + tests/test_builtinsDebug.cpp + tests/test_builtinsAll.cpp + tests/test_branchVar.cpp + tests/test_branchZero.cpp + tests/test_compare.cpp + tests/test_const.cpp + tests/test_control.cpp + tests/test_debugInfo.cpp + tests/test_defineAsm.cpp + tests/test_dma.cpp + tests/test_immediateScalar.cpp + tests/test_labels.cpp + tests/test_load.cpp + tests/test_loop.cpp + tests/test_macros.cpp + tests/test_preproc.cpp + tests/test_scalarOps.cpp + tests/test_scope.cpp + tests/test_state.cpp + tests/test_stateDataBss.cpp + tests/test_store.cpp + tests/test_swizzle.cpp + tests/test_syntaxExpansion.cpp + tests/test_optAssert.cpp + tests/test_optBranchJump.cpp + tests/test_optDeadCode.cpp + tests/test_optDedupeImm.cpp + tests/test_optDelaySlot.cpp + tests/test_optJumpDedupe.cpp + tests/test_optLabels.cpp + tests/test_optDepScanCtrl.cpp + tests/test_optDepScanMem.cpp + tests/test_optDepScanRegs.cpp + tests/test_optMergeSequence.cpp + tests/test_optRegScan.cpp + tests/test_evalCost.cpp + tests/test_evalCostExample.cpp + tests/test_examples.cpp + tests/test_vectorOps.cpp + tests/test_syntaxNumbers.cpp + tests/test_syntaxVar.cpp +) + +add_executable(rspl_tests ${TEST_SOURCES}) +target_link_libraries(rspl_tests PRIVATE rspl_core Catch2::Catch2WithMain) +include(CTest) +add_test(NAME rspl_tests COMMAND rspl_tests) diff --git a/cpp/README.md b/cpp/README.md new file mode 100644 index 0000000..dda5bc1 --- /dev/null +++ b/cpp/README.md @@ -0,0 +1,50 @@ +# RSPL C++ Backend + +Native C++ port of the RSPL transpiler. Parsing is delegated to the existing +JS parser (`scripts/parse.js`); everything downstream runs in C++. + +## Build + +Requirements: **CMake 3.20+**, **g++ 13+** (or clang++ with C++20), **Node.js 20+**. + +```sh +cd cpp + +# Debug build (no optimizations, asserts enabled, debug symbols) +cmake -B build -DCMAKE_BUILD_TYPE=Debug +cmake --build build + +# Production build (-O3, no asserts, stripped) +cmake -B build -DCMAKE_BUILD_TYPE=Release +cmake --build build +``` + +Dependencies (nlohmann/json, Catch2) are fetched automatically by CMake. + +## Run + +From the repo root: + +```sh +cpp/build/rspl input.rspl # full pipeline → stdout +cpp/build/rspl input.rspl -o output.S # write to file +cpp/build/rspl input.rspl --no-optimize # skip optimizer +cpp/build/rspl input.rspl --no-rspq # raw asm, no RSPQ wrapper +cpp/build/rspl input.rspl --ast-dump # dump parsed AST (debug) +cpp/build/rspl input.rspl -D FOO=42 # preprocessor define +cpp/build/rspl input.rspl --reorder # enable instruction reorder annealing +cpp/build/rspl input.rspl --opt-time=60 # optimizer time budget in seconds +``` + +The binary invokes `node scripts/parse.js` internally. Set `RSPL_PARSE_JS` env var +to override the script path: + +```sh +RSPL_PARSE_JS=/path/to/custom/parse.js cpp/build/rspl input.rspl +``` + +## Tests + +```sh +cpp/build/rspl_tests +``` diff --git a/cpp/src/asm.cpp b/cpp/src/asm.cpp new file mode 100644 index 0000000..dce999a --- /dev/null +++ b/cpp/src/asm.cpp @@ -0,0 +1,188 @@ +#include "asm.h" +#include "state.h" + +#include +#include +#include + +namespace rspl { + +// --- Opcode registry --------------------------------------------------- +// Function-local statics to avoid static-init-order fiasco across TUs. + +struct OpcodeRegistry { + std::vector names; + std::unordered_map map; +}; +static OpcodeRegistry &opcodeReg() { + static OpcodeRegistry reg; + return reg; +} + +Opcode getOpcode(const std::string &name) { + if (name.empty()) return 0; + auto ® = opcodeReg(); + auto it = reg.map.find(name); + if (it != reg.map.end()) return it->second; + Opcode idx = static_cast(reg.names.size() + 1); + reg.names.push_back(name); + reg.map[name] = idx; + return idx; +} + +const std::string &getOpcodeName(Opcode op) { + static const std::string empty; + auto &names = opcodeReg().names; + if (op == 0 || op > names.size()) return empty; + return names[op - 1]; +} + +// --- Precomputed opcode → (flags, latency) map -------------------------- +// Replaces the 8 separate unordered_set lookups with a single map lookup. + +struct OpInfoEntry { uint32_t flags; int latency; }; + +static const std::unordered_map OP_INFO_MAP = []() { + std::unordered_map m; + + auto add = [&](const char *op, uint32_t flags, int latency) { + m[getOpcode(op)] = {flags, latency}; + }; + + // Branches + for (auto *op : {"beq","bne","bgezal","bltzal","bgez","bltz","blez","bgtz", + "j","jr","jal"}) + add(op, OP_FLAG_IS_BRANCH | OP_FLAG_IS_IMMOVABLE, 0); + + // Stores (also get MEM_STALL_STORE) + for (auto *op : {"sw","sh","sb","sbv","ssv","slv","sdv","sqv","spv","suv", + "shv","sfv","stv","swv","srv"}) + add(op, OP_FLAG_IS_STORE | OP_FLAG_IS_MEM_STALL_STORE, 0); + + // Vector loads + for (auto *op : {"lbv","lsv","llv","ldv","lqv","lpv","luv","lhv","lfv", + "ltv","lrv"}) + add(op, OP_FLAG_IS_LOAD | OP_FLAG_IS_MEM_STALL_LOAD, 4); + + // Scalar loads + for (auto *op : {"lw","lh","lhu","lb","lbu"}) + add(op, OP_FLAG_IS_LOAD | OP_FLAG_IS_MEM_STALL_LOAD, 3); + + // Stall ops (both load and store stalls) + for (auto *op : {"mfc0","mtc0","mfc2","mtc2","cfc2","ctc2"}) { + uint32_t f = OP_FLAG_IS_MEM_STALL_LOAD | OP_FLAG_IS_MEM_STALL_STORE; + int lat = 3; + if (op[2] == 'c') { // cfc2/ctc2 + f |= OP_FLAG_CTC2_CFC2; + } + if (op[1] == 't' && op[2] == 'c' && op[3] == '2') lat = 4; // mtc2 + add(op, f, lat); + } + + // Special + add("nop", OP_FLAG_IS_NOP | OP_FLAG_IS_IMMOVABLE, 0); + add("catch", OP_FLAG_IS_MEM_STALL_LOAD | OP_FLAG_IS_MEM_STALL_STORE, 0); + + return m; +}(); + +int getStallLatency(Opcode op) { + auto it = OP_INFO_MAP.find(op); + if (it != OP_INFO_MAP.end()) return it->second.latency; + const auto &name = getOpcodeName(op); + if (!name.empty() && name[0] == 'v') return 4; + return 0; +} + +uint32_t getOpFlags(Opcode op) { + uint32_t flags = OP_FLAG_IS_LIKELY; + auto it = OP_INFO_MAP.find(op); + if (it != OP_INFO_MAP.end()) flags |= it->second.flags; + else { + const auto &name = getOpcodeName(op); + if (!name.empty() && name[0] == 'v') flags |= OP_FLAG_IS_VECTOR; + } + if (op == getOpcode("nop")) flags |= OP_FLAG_IS_NOP; + + if ((flags & OP_FLAG_IS_BRANCH) && (flags & OP_FLAG_IS_LIKELY)) + flags |= OP_FLAG_LIKELY_BRANCH; + return flags; +} + +static AsmDebug currentDebug() { + return AsmDebug{.lineRSPL = static_cast(state.line)}; +} + +static void applyOpInfo(AsmInst &inst, Opcode op, + AsmType type) { + inst.type = type; + inst.opFlags = getOpFlags(op); + inst.stallLatency = getStallLatency(op); + // Copy current annotations from state (but don't clear — + // clearing is done per-statement in scopedBlockToAsm to match JS) + for (const auto &ann : state.getAnnotations()) { + inst.cold->annotations.push_back({ann.name, ann.value}); + } +} + +// --- Factory functions ------------------------------------------------ + +AsmInst asmOp(const std::string &op, + const std::vector &args) { + AsmInst inst; + inst.op = getOpcode(op); + inst.args = args; + inst.debug = currentDebug(); + applyOpInfo(inst, inst.op, AsmType::OP); + return inst; +} + +AsmInst asmNOP() { + AsmInst inst; + inst.op = getOpcode("nop"); + inst.debug = currentDebug(); + applyOpInfo(inst, inst.op, AsmType::OP); + return inst; +} + +AsmInst asmLabel(const std::string &label) { + AsmInst inst; + inst.cold->label = label; + inst.op = 0; + inst.debug = currentDebug(); + applyOpInfo(inst, 0, AsmType::LABEL); + return inst; +} + +AsmInst asmBranch(const std::string &op, + const std::vector &args, + const std::string &labelEnd) { + AsmInst inst = asmOp(op, args); + inst.cold->labelEnd = labelEnd; + return inst; +} + +AsmInst asmInline(const std::string &op, + const std::vector &args) { + AsmInst inst; + inst.op = getOpcode(op); + inst.args = args; + inst.debug = currentDebug(); + applyOpInfo(inst, inst.op, AsmType::INLINE); + return inst; +} + +AsmInst asmFunction(const std::string &target, + const std::vector &argRegs, + bool relative) { + if (relative) { + AsmInst inst = asmOp("bgezal", {"$zero", target}); + inst.cold->funcArgs = argRegs; + return inst; + } + AsmInst inst = asmOp("jal", {target}); + inst.cold->funcArgs = argRegs; + return inst; +} + +} // namespace rspl diff --git a/cpp/src/asm.h b/cpp/src/asm.h new file mode 100644 index 0000000..7dabe25 --- /dev/null +++ b/cpp/src/asm.h @@ -0,0 +1,202 @@ +#pragma once + +#include "types.h" + +#include +#include +#include +#include +#include +#include + +namespace rspl { + +// --- ASM type enum ---------------------------------------------------- + +enum AsmType : uint8_t { OP = 0, LABEL = 1, INLINE = 3 }; + +// --- Op flags (bitfield) ---------------------------------------------- + +enum OpFlag : uint32_t { + OP_FLAG_IS_LOAD = 1u << 0, + OP_FLAG_IS_STORE = 1u << 1, + OP_FLAG_IS_BRANCH = 1u << 2, + OP_FLAG_IS_IMMOVABLE = 1u << 3, + OP_FLAG_IS_MEM_STALL_LOAD = 1u << 4, + OP_FLAG_IS_MEM_STALL_STORE = 1u << 5, + OP_FLAG_IS_VECTOR = 1u << 6, + OP_FLAG_IS_NOP = 1u << 7, + OP_FLAG_IS_LIKELY = 1u << 8, + OP_FLAG_LIKELY_BRANCH = 1u << 9, + OP_FLAG_CTC2_CFC2 = 1u << 10, +}; + +// --- Annotation on an instruction ------------------------------------- + +struct AsmAnnotation { + std::string name; + std::string value; +}; + +// --- Debug information ------------------------------------------------ + +struct AsmDebug { + int lineASM = 0; + int lineRSPL = 0; + int lineASMOpt = 0; + int reorderCount = 0; + int reorderLineMin = 0; + int reorderLineMax = 0; + int cycle = 0; + int stall = 0; + bool paired = false; +}; + +// --- Cold metadata — shared between clones via shared_ptr to avoid --------- +// deep-copying label/annotations during optimization cloning. + +struct AsmInstCold { + std::string label; + std::string labelEnd; + std::vector funcArgs; + std::vector annotations; +}; + +// --- Compact opcode representation ------------------------------------ +// Replaces std::string op (32 bytes) with a uint16_t index (2 bytes). +// Opcode strings are interned in a global table; getOpcode() converts +// string → index, getOpcodeName() converts index → string (for output). + +using Opcode = uint16_t; +Opcode getOpcode(const std::string &name); +const std::string &getOpcodeName(Opcode op); + +// Cached Opcode constants for common comparisons (fast after first use). +// Use: inst.op == OP_NOP instead of inst.op == getOpcode("nop"). +namespace Op { + inline Opcode NOP() { static Opcode o = getOpcode("nop"); return o; } + inline Opcode J() { static Opcode o = getOpcode("j"); return o; } + inline Opcode JR() { static Opcode o = getOpcode("jr"); return o; } + inline Opcode JAL() { static Opcode o = getOpcode("jal"); return o; } + inline Opcode BEQ() { static Opcode o = getOpcode("beq"); return o; } + inline Opcode BNE() { static Opcode o = getOpcode("bne"); return o; } + inline Opcode MTC2() { static Opcode o = getOpcode("mtc2"); return o; } + inline Opcode MTC0() { static Opcode o = getOpcode("mtc0"); return o; } + inline Opcode CTC2() { static Opcode o = getOpcode("ctc2"); return o; } + inline Opcode STV() { static Opcode o = getOpcode("stv"); return o; } + inline Opcode LTV() { static Opcode o = getOpcode("ltv"); return o; } + inline Opcode LUI() { static Opcode o = getOpcode("lui"); return o; } + inline Opcode ADDIU() { static Opcode o = getOpcode("addiu"); return o; } + inline Opcode ADDU() { static Opcode o = getOpcode("addu"); return o; } + inline Opcode ORI() { static Opcode o = getOpcode("ori"); return o; } + inline Opcode VMUDL() { static Opcode o = getOpcode("vmudl"); return o; } + inline Opcode VXOR() { static Opcode o = getOpcode("vxor"); return o; } + inline Opcode VRSQH() { static Opcode o = getOpcode("vrsqh"); return o; } + inline Opcode VRCPH() { static Opcode o = getOpcode("vrcph"); return o; } + inline Opcode VRSQL() { static Opcode o = getOpcode("vrsql"); return o; } + inline Opcode VRCPL() { static Opcode o = getOpcode("vrcpl"); return o; } +} + +// --- SmallVec: fixed-capacity vector, no heap allocation --------- + +template struct SmallVec { + std::array data_{}; + uint8_t size_ = 0; + T &operator[](size_t i) { return data_[i]; } + const T &operator[](size_t i) const { return data_[i]; } + T *begin() { return data_.data(); } + T *end() { return data_.data() + size_; } + const T *begin() const { return data_.data(); } + const T *end() const { return data_.data() + size_; } + size_t size() const { return size_; } + bool empty() const { return size_ == 0; } + void push_back(const T &v) { data_[size_++] = v; } + void clear() { size_ = 0; } + T &back() { return data_[size_ - 1]; } + const T &back() const { return data_[size_ - 1]; } + void pop_back() { --size_; } + SmallVec &operator=(const std::vector &v) { + size_ = 0; + for (const auto &e : v) push_back(e); + return *this; + } +}; + +// --- ASM instruction -------------------------------------------------- + +struct AsmInst { + Opcode op = 0; // mnemonic e.g. "add", "beq", "nop" + std::vector args; // operands + AsmType type = AsmType::OP; + + uint32_t opFlags = 0; // bitfield of OpFlag + int stallLatency = 0; + + AsmDebug debug; + uint32_t barrierMask = 0; + + // -- Dependency tracking (filled by optimizer) ----------------------- + SmallVec depsSourceIdx; + SmallVec depsTargetIdx; + SmallVec depsStallSourceIdx; + SmallVec depsStallTargetIdx; + + // 295-bit register masks stored as 5 x uint64_t + std::array depsSourceMask = {}; + std::array depsTargetMask = {}; + + uint32_t depsStallSourceMask0 = 0; + uint32_t depsStallSourceMask1 = 0; + uint32_t depsStallTargetMask0 = 0; + uint32_t depsStallTargetMask1 = 0; + + // Cold metadata — shallow-copied during clone (shared_ptr refcount bump) + std::shared_ptr cold = std::make_shared(); +}; + +// --- Function-level ASM ----------------------------------------------- + +struct AsmFunc { + std::string name; + FuncType type = FuncType::Function; // function, command, or macro + std::vector asm_; + int argSize = 0; + int cyclesBefore = 0; + int cyclesAfter = 0; + std::vector annotations; // from AST + std::optional resultType; + std::string nameOverride; // for command aliasing + + // Temp iteration count for reorder worker + int _iterCount = 0; +}; + +// --- ASM output ------------------------------------------------------- + +struct AsmOutput { + std::string asm_; + int sizeDMEM = 0; + int sizeIMEM = 0; +}; + +// --- Factory functions ------------------------------------------------ + +AsmInst asmOp(const std::string &op, + const std::vector &args = {}); +AsmInst asmNOP(); +AsmInst asmLabel(const std::string &label); +AsmInst asmBranch(const std::string &op, + const std::vector &args, + const std::string &labelEnd); +AsmInst asmInline(const std::string &op, + const std::vector &args = {}); +AsmInst asmFunction(const std::string &target, + const std::vector &argRegs, + bool relative = false); + +// --- Op classification helpers ---------------------------------------- + +int getStallLatency(Opcode op); +uint32_t getOpFlags(Opcode op); + +} // namespace rspl diff --git a/cpp/src/asm_normalize.cpp b/cpp/src/asm_normalize.cpp new file mode 100644 index 0000000..893fd56 --- /dev/null +++ b/cpp/src/asm_normalize.cpp @@ -0,0 +1,49 @@ +#include "asm_normalize.h" +#include "asm.h" +#include "registers.h" + +#include + +namespace rspl { + +// READ_ONLY_OPS: BRANCH_OPS + STORE_OPS + ["mtc0"] +static const std::unordered_set READ_ONLY_OPS = []() { + std::unordered_set s; + for (auto *op : {"beq","bne","bgezal","bltzal","bgez","bltz", + "blez","bgtz","j","jr","jal", + "sw","sh","sb","sbv","ssv","slv","sdv", + "sqv","spv","suv","shv","sfv","stv","swv","srv", + "mtc0"}) + s.insert(getOpcode(op)); + return s; +}(); + +void normalizeASM(AsmFunc &func) { + std::vector result; + result.reserve(func.asm_.size()); + + for (auto &inst : func.asm_) { + if (inst.type != AsmType::OP || + READ_ONLY_OPS.count(inst.op) || + inst.args.empty()) { + result.push_back(std::move(inst)); + continue; + } + + // Ignore writes to $zero or $vzero (including element-suffixed like $v00.e0) + std::string targetReg = + (inst.op == Op::MTC2()) ? inst.args[1] : inst.args[0]; + auto dotPos = targetReg.find('.'); + std::string baseReg = + (dotPos != std::string::npos) ? targetReg.substr(0, dotPos) : targetReg; + if (baseReg == reg::Reg::ZERO || baseReg == reg::Reg::VZERO) { + continue; // drop this instruction + } + + result.push_back(std::move(inst)); + } + + func.asm_ = std::move(result); +} + +} // namespace rspl diff --git a/cpp/src/asm_normalize.h b/cpp/src/asm_normalize.h new file mode 100644 index 0000000..665b34b --- /dev/null +++ b/cpp/src/asm_normalize.h @@ -0,0 +1,10 @@ +#pragma once + +namespace rspl { + +struct AsmFunc; + +/// Remove instructions that write to $zero or $vzero (dead writes). +void normalizeASM(AsmFunc &func); + +} // namespace rspl diff --git a/cpp/src/asm_writer.cpp b/cpp/src/asm_writer.cpp new file mode 100644 index 0000000..78adacb --- /dev/null +++ b/cpp/src/asm_writer.cpp @@ -0,0 +1,379 @@ +#include "asm_writer.h" +#include "asm.h" +#include "ast.h" +#include "registers.h" +#include "state.h" +#include "types.h" + +#include +#include +#include + +namespace rspl { + +std::string stringifyInstr(const AsmInst &inst) { + if (inst.op == 0) return inst.cold->label + ":"; + if (inst.args.empty()) return getOpcodeName(inst.op); + std::ostringstream ss; + ss << getOpcodeName(inst.op); + for (size_t i = 0; i < inst.args.size(); ++i) { + ss << (i == 0 ? " " : ", ") << inst.args[i]; + } + return ss.str(); +} + +static std::string makePadding(size_t len, size_t target) { + if (len >= target) return " "; + return std::string(target - len, ' '); +} + +AsmWriteResult writeASM(const ast::Program &ast, + const std::vector &functions, + const WriteConfig &config) { + state.func = "(ASM)"; + state.line = 0; + + AsmWriteResult res; + std::ostringstream out; + int asmLine = 0; // physical ASM line count + + auto writeLine = [&](const std::string &line) { + out << line << "\n"; + ++asmLine; + ++state.line; + }; + + auto writeLines = [&](const std::vector &lines) { + for (const auto &l : lines) { + writeLine(l); + } + }; + + writeLine("## Auto-generated file, transpiled with RSPL"); + + // Defines from preprocessor + for (const auto &def : ast.defines) { + writeLine("#define " + def.name + " " + def.value); + } + + // Includes + for (const auto &inc : ast.includes) { + std::string path = inc; + // Strip surrounding quotes if present + if (path.size() >= 2 && path.front() == '"' && path.back() == '"') { + path = path.substr(1, path.size() - 2); + } + bool local = !path.empty() && path[0] == '.'; + writeLine(std::string("#include ") + (local ? "\"" : "<") + path + + (local ? "\"" : ">")); + } + + writeLines({ + "", ".set noreorder", ".set noat", ".set nomacro", "", + }); + + // Undefines + equ for scalar registers + for (const auto ® : reg::REGS_SCALAR) { + writeLine("#undef " + reg.substr(1)); + } + for (size_t i = 0; i < reg::REGS_SCALAR.size(); ++i) { + writeLine(".equ hex." + reg::REGS_SCALAR[i] + ", " + + std::to_string(i)); + } + writeLines({"#define vco 0", "#define vcc 1", "#define vce 2"}); + + writeLines({"", ".data", " RSPQ_BeginOverlayHeader"}); + + // Command list + int maxResultType = -1; + for (const auto &fn : functions) { + if (fn.type == FuncType::Command) { + int rt = fn.resultType.value_or(-1); + if (rt > maxResultType) maxResultType = rt; + } + } + maxResultType = std::max(maxResultType, -1); + + std::vector commandList(maxResultType + 1, + " RSPQ_DefineCommand RSPQ_Loop, 4"); + for (const auto &fn : functions) { + if (fn.type == FuncType::Command && fn.resultType.has_value()) { + std::string name = + fn.nameOverride.empty() ? fn.name : fn.nameOverride; + commandList[fn.resultType.value()] = + " RSPQ_DefineCommand " + name + ", " + + std::to_string(std::max(fn.argSize * 4, 4)); + } + } + writeLines(commandList); + writeLines({" RSPQ_EndOverlayHeader", ""}); + + // State sections + // Collect state vars from AST + std::vector stateVars; + std::vector dataVars; + std::vector bssVars; + + for (const auto &sec : ast.states) { + for (const auto &v : sec.vars) { + if (sec.name == "state" || sec.name.empty()) + stateVars.push_back(v); + else if (sec.name == "data") + dataVars.push_back(v); + else if (sec.name == "bss" || sec.name == "temp_state") + bssVars.push_back(v); + else + stateVars.push_back(v); // default to state + } + } + + int totalSaveByteSize = 0; + int totalTextSize = 0; + + bool hasState = std::any_of(stateVars.begin(), stateVars.end(), + [](auto &v) { return !v.isExtern; }); + + if (hasState) { + writeLine(" RSPQ_BeginSavedState"); + writeLine(" STATE_MEM_START:"); + + for (const auto &sv : stateVars) { + if (sv.isExtern) continue; + int arraySize = 1; + for (auto dim : sv.arraySize) arraySize *= dim; + if (arraySize < 1) arraySize = 1; + int byteSize = (TYPE_SIZE.count(sv.varType) + ? TYPE_SIZE.at(sv.varType) + : 4) * + arraySize; + + int align = TYPE_ALIGNMENT.count(sv.varType) + ? TYPE_ALIGNMENT.at(sv.varType) + : 0; + if (sv.align != 0) { + align = static_cast(std::log2(sv.align)); + } + if (align > 0) { + writeLine(" .align " + std::to_string(align)); + } + + if (sv.value.empty()) { + writeLine(" " + sv.varName + ": .ds.b " + + std::to_string(byteSize)); + } else { + auto asmDefIt = TYPE_ASM_DEF.find(sv.varType); + std::string asmType = (asmDefIt != TYPE_ASM_DEF.end()) + ? asmDefIt->second.type + : "word"; + int asmCount = (asmDefIt != TYPE_ASM_DEF.end()) + ? asmDefIt->second.count + : 1; + int arrayCount = arraySize / asmCount; + if (arrayCount < 1) arrayCount = 1; + // Write data with correct type + int totalCount = asmCount * arrayCount; + std::vector data(totalCount, 0.0); + for (size_t i = 0; i < sv.value.size() && i < static_cast(totalCount); ++i) { + data[i] = sv.value[i]; + } + std::ostringstream ss; + for (int i = 0; i < totalCount; ++i) { + if (i) ss << ", "; + ss << static_cast(data[i]); + } + writeLine(" " + sv.varName + ": ." + asmType + " " + ss.str()); + } + totalSaveByteSize += byteSize; + } + + writeLine(" STATE_MEM_END:"); + writeLine(" RSPQ_EndSavedState"); + } else { + writeLine(" RSPQ_EmptySavedState"); + } + + // Helper to emit a single state var (with alignment and size) + auto emitStateVar = [&](const ast::StateVarDef &sv) { + int arraySize = 1; + for (auto dim : sv.arraySize) arraySize *= dim; + if (arraySize < 1) arraySize = 1; + int byteSize = (TYPE_SIZE.count(sv.varType) + ? TYPE_SIZE.at(sv.varType) + : 4) * + arraySize; + int align = TYPE_ALIGNMENT.count(sv.varType) + ? TYPE_ALIGNMENT.at(sv.varType) + : 0; + if (sv.align != 0) + align = static_cast(std::log2(sv.align)); + if (align > 0) + writeLine(" .align " + std::to_string(align)); + writeLine(" " + sv.varName + ": .ds.b " + + std::to_string(byteSize)); + }; + + // Data section + if (!dataVars.empty()) { + writeLine(""); + for (const auto &dv : dataVars) { + if (dv.isExtern) continue; + emitStateVar(dv); + } + } + + // BSS section + if (!bssVars.empty()) { + writeLine(""); + writeLine(".bss"); + writeLine(" TEMP_STATE_MEM_START:"); + for (const auto &bv : bssVars) { + if (bv.isExtern) continue; + emitStateVar(bv); + } + writeLine(" TEMP_STATE_MEM_END:"); + } + + writeLines({"", ".text", "OVERLAY_CODE_START:", ""}); + + // For non-wrapper output, reset here — the headers above were only + // needed for state.line to advance correctly (matching JS asmWriter.js:184-186) + if (!config.rspqWrapper) { + state.line = 1; + out.str(""); + out.clear(); + } + + // Function bodies + for (const auto &fn : functions) { + if (fn.asm_.empty()) continue; + + // Emit .align from @Align(N) annotation + for (const auto &ann : fn.annotations) { + if (ann.name == "Align" && !ann.value.empty()) { + int alignBytes = std::stoi(ann.value); + int alignExp = static_cast(std::log2(alignBytes)); + if (alignExp > 0 && (1 << alignExp) == alignBytes) { + writeLine(".align " + std::to_string(alignExp)); + } + } + } + + writeLine(fn.name + ":"); + + // Track last cycle for debug info (matching JS asmWriter.js) + int lastCycle = fn.asm_.empty() ? 0 : fn.asm_[0].debug.cycle; + + for (const auto &inst : fn.asm_) { + if (inst.type == AsmType::LABEL) { + std::string tag; + for (const auto &ann : inst.cold->annotations) { + if (ann.name == "Tag") + tag = "TAG_" + ann.value + ": "; + } + writeLine(" " + tag + inst.cold->label + ":"); + } else { + // Build raw instruction string (matching JS stringifyInstr) + std::string rawInstr = stringifyInstr(inst); + + // Determine tag prefix + std::string tag; + for (const auto &ann : inst.cold->annotations) { + if (ann.name == "Tag") + tag = "TAG_" + ann.value + ": "; + } + + std::string instr; + if (config.debugInfo) { + // Pad instruction to 51 chars, then prepend prefix + tag + std::string padded = rawInstr; + if (padded.size() < 51) + padded.append(51 - padded.size(), ' '); + instr = " " + tag + padded; + + // Build debug info string + std::ostringstream di; + if (inst.debug.lineRSPL) { + std::string cycleStr = " ^"; + int cycleDiff = inst.debug.cycle - lastCycle; + if (cycleDiff != 0) { + std::string stars; + if (cycleDiff > 1) + stars.append(cycleDiff - 1, '*'); + cycleStr = stars + std::to_string(inst.debug.cycle); + if (cycleStr.size() < 6) + cycleStr.insert(0, 6 - cycleStr.size(), ' '); + } + std::string lineStr = std::to_string(inst.debug.lineRSPL); + if (lineStr.size() < 4) + lineStr.append(4 - lineStr.size(), ' '); + di << "## L:" << lineStr << " | " << cycleStr << " | "; + if (inst.debug.lineRSPL > 0 && + inst.debug.lineRSPL <= + static_cast(state.sourceLines.size())) { + di << state.sourceLines[inst.debug.lineRSPL - 1]; + } + } + + if (!inst.cold->funcArgs.empty()) { + di << " ## Args: "; + for (size_t i = 0; i < inst.cold->funcArgs.size(); ++i) { + if (i) di << ", "; + di << inst.cold->funcArgs[i]; + } + } + + if (inst.barrierMask) { + std::ostringstream bs; + bs << std::hex << std::uppercase << inst.barrierMask; + di << " ## Barrier: 0x" << bs.str(); + } + + instr += di.str(); + } else { + instr = " " + tag + rawInstr; + } + + writeLine(instr); + totalTextSize += 4; + lastCycle = inst.debug.cycle; + } + } + } + + writeLine(""); + + if (!config.rspqWrapper) { + res.asm_ = out.str(); + return res; + } + + writeLine("OVERLAY_CODE_END:"); + writeLine(""); + + // Register defines + for (size_t i = 0; i < reg::REGS_SCALAR.size(); ++i) { + if (i == 1) continue; // skip $at + writeLine("#define " + reg::REGS_SCALAR[i].substr(1) + " $" + + std::to_string(i)); + } + + writeLines({"", ".set at", ".set macro"}); + + // Post includes + for (const auto &inc : ast.postIncludes) { + std::string path = inc; + if (path.size() >= 2 && path.front() == '"' && path.back() == '"') { + path = path.substr(1, path.size() - 2); + } + bool local = !path.empty() && path[0] == '.'; + writeLine(std::string("#include ") + (local ? "\"" : "<") + path + + (local ? "\"" : ">")); + } + + res.asm_ = out.str(); + res.sizeDMEM = totalSaveByteSize; + res.sizeIMEM = totalTextSize; + return res; +} + +} // namespace rspl diff --git a/cpp/src/asm_writer.h b/cpp/src/asm_writer.h new file mode 100644 index 0000000..4798db8 --- /dev/null +++ b/cpp/src/asm_writer.h @@ -0,0 +1,41 @@ +#pragma once + +#include "asm.h" +#include "ast.h" + +#include +#include +#include +#include + +namespace rspl { + +struct WriteConfig { + bool rspqWrapper = true; + bool debugInfo = true; +}; + +struct AsmWriteResult { + std::string asm_; + int sizeDMEM = 0; + int sizeIMEM = 0; + // lineMap: RSPL source line -> vector of ASM lines + std::unordered_map> lineMap; + // lineDepMap: ASM line -> [min, max] reorder range + std::unordered_map> lineDepMap; + // lineOptMap: original ASM line -> optimized ASM line + std::unordered_map lineOptMap; + // lineCycleMap: optimized ASM line -> cycle + std::unordered_map lineCycleMap; + // lineStallMap: optimized ASM line -> stall count + std::unordered_map lineStallMap; +}; + +AsmWriteResult writeASM(const ast::Program &ast, + const std::vector &functions, + const WriteConfig &config); + +// Extracted for testing: format a single instruction to text +std::string stringifyInstr(const AsmInst &inst); + +} // namespace rspl diff --git a/cpp/src/ast.cpp b/cpp/src/ast.cpp new file mode 100644 index 0000000..29fde5a --- /dev/null +++ b/cpp/src/ast.cpp @@ -0,0 +1,449 @@ +#include "ast.h" + +#include +#include + +namespace rspl::ast { + +using json = nlohmann::json; + +// --- Forward declarations for recursive deserializers ----------------- + +static ScopedBlock parseScopedBlock(const json &j); +static Calc parseCalc(const json &j); +static CompareExpr parseCompareExpr(const json &j); + +// --- Helpers ---------------------------------------------------------- + +static inline std::string optStr(const json &j, const char *key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) return {}; + if (it->is_string()) return it->get(); + if (it->is_number()) return std::to_string(it->get()); + if (it->is_boolean()) return it->get() ? "true" : "false"; + return {}; +} + +static inline std::string jsonAsStr(const json &j) { + if (j.is_string()) return j.get(); + if (j.is_number()) return std::to_string(j.get()); + if (j.is_boolean()) return j.get() ? "true" : "false"; + return j.dump(); +} + +static inline uint32_t optLine(const json &j) { + auto it = j.find("line"); + if (it == j.end() || it->is_null()) return 0; + if (it->is_number()) return it->get(); + return 0; +} + +// --- Leaf types ------------------------------------------------------- + +static ExprNum parseExprNum(const json &j) { + return ExprNum{j.value("value", 0)}; +} + +static ExprVarName parseExprVarName(const json &j) { + return ExprVarName{j.value("value", "")}; +} + +// --- FuncArg ---------------------------------------------------------- + +static FuncArg parseFuncArg(const json &j) { + return FuncArg{ + .type = toArgType(j.value("type", "")), + .value = jsonAsStr(j["value"]), + .swizzle = optStr(j, "swizzle"), + }; +} + +// --- Annotation ------------------------------------------------------- + +static Annotation parseAnnotation(const json &j) { + return Annotation{ + .name = j.value("name", ""), + .value = optStr(j, "value"), + }; +} + +// --- FuncDefArg ------------------------------------------------------- + +static FuncDefArg parseFuncDefArg(const json &j) { + return FuncDefArg{ + .type = toTypeClass(j.value("type", "")), + .reg = optStr(j, "reg"), + .name = j.value("name", ""), + }; +} + +// --- CalcParse -> dispatches on "type" field --------------------------- + +static Calc parseCalc(const json &j) { + std::string type = j.value("type", ""); + if (type == "calcNum") { + // calcNum.right is a plain number, not an object + CalcNum cn; + if (j["right"].is_number()) { + cn.right = ExprNum{j["right"].get()}; + } else if (j["right"].is_object()) { + cn.right = parseExprNum(j["right"]); + } + return cn; + } + if (type == "calcVar") { + return CalcVar{ + .op = optStr(j, "op"), + .right = parseExprVarName(j["right"]), + .swizzleRight = optStr(j, "swizzleRight"), + }; + } + if (type == "calcLR") { + CalcLR lr; + lr.left = parseExprVarName(j["left"]); + lr.op = j.value("op", ""); + lr.swizzleLeft = optStr(j, "swizzleLeft"); + lr.swizzleRight = optStr(j, "swizzleRight"); + // right can be VarName or num + if (j["right"].is_object() && j["right"].value("type", "") == "num") { + lr.rightNum = ExprNum{j["right"].value("value", 0)}; + } else if (j["right"].is_object()) { + lr.rightVarName = j["right"].value("value", ""); + } + return lr; + } + if (type == "calcMulti") { + CalcMulti cm; + if (j["left"].is_object() && j["left"].value("type", "") == "num") { + cm.leftVal = j["left"].value("value", int64_t{0}); + } else { + cm.left = parseExprVarName(j["left"]); + } + cm.swizzleLeft = optStr(j, "swizzleLeft"); + cm.groupStart = j.value("groupStart", 0); + if (j.contains("parts") && j["parts"].is_array()) { + for (const auto &p : j["parts"]) { + CalcMultiPart part; + part.op = p.value("op", ""); + part.swizzleRight = optStr(p, "swizzleRight"); + part.groupStart = p.value("groupStart", 0); + part.groupEnd = p.value("groupEnd", 0); + if (p["right"].is_object() && p["right"].value("type", "") == "num") { + part.rightVal = p["right"].value("value", int64_t{0}); + } else { + part.right = parseExprVarName(p["right"]); + } + cm.parts.push_back(std::move(part)); + } + } + return cm; + } + if (type == "calcFunc") { + CalcFunc cf; + cf.funcName = j.value("funcName", ""); + cf.swizzleRight = optStr(j, "swizzleRight"); + if (j.contains("args") && j["args"].is_array()) { + for (const auto &a : j["args"]) { + cf.args.push_back(parseFuncArg(a)); + } + } + return cf; + } + if (type == "calcCompare") { + CalcCompare cc; + cc.left = j.value("left", ""); + cc.op = j.value("op", ""); + cc.swizzleRight = optStr(j, "swizzleRight"); + if (j["right"].is_number()) { + cc.rightVal = j["right"].get(); + } else if (j["right"].is_object() && j["right"].value("type", "") == "num") { + cc.rightVal = j["right"].value("value", 0.0); + } else if (j["right"].is_object()) { + cc.right = j["right"].value("value", ""); + } else if (j["right"].is_string()) { + cc.right = j["right"].get(); + } + if (j.contains("ternary") && !j["ternary"].is_null()) { + TernaryPart tp; + tp.left = j["ternary"].value("left", ""); + tp.swizzleRight = optStr(j["ternary"], "swizzleRight"); + if (j["ternary"]["right"].is_number()) { + tp.rightVal = j["ternary"]["right"].get(); + } else if (j["ternary"]["right"].is_object() && + j["ternary"]["right"].value("type", "") == "num") { + tp.rightVal = j["ternary"]["right"].value("value", 0.0); + } else if (j["ternary"]["right"].is_object()) { + tp.right = j["ternary"]["right"].value("value", ""); + } else if (j["ternary"]["right"].is_string()) { + tp.right = j["ternary"]["right"].get(); + } + cc.ternary = std::move(tp); + } + return cc; + } + throw std::runtime_error("Unknown calc type: " + type); +} + +// --- CompareExpr ------------------------------------------------------ + +static CompareExpr parseCompareExpr(const json &j) { + return CompareExpr{ + .left = parseFuncArg(j["left"]), + .op = j.value("op", ""), + .right = parseFuncArg(j["right"]), + .line = optLine(j), + }; +} + +// --- ScopedBlock ------------------------------------------------------ + +static ScopedBlock parseScopedBlock(const json &j) { + ScopedBlock block; + block.line = optLine(j); + if (j.contains("statements") && j["statements"].is_array()) { + for (const auto &st : j["statements"]) { + std::string stType = st.value("type", ""); + if (stType == "varDecl") { + block.statements.push_back(StmtVarDecl{ + .varName = st.value("varName", ""), + .varType = st.value("varType", ""), + .reg = optStr(st, "reg"), + .isConst = st.value("isConst", false), + .line = optLine(st), + }); + } else if (stType == "varDeclMulti") { + StmtVarDeclMulti s; + s.varType = st.value("varType", ""); + s.reg = optStr(st, "reg"); + s.isConst = st.value("isConst", false); + s.line = optLine(st); + if (st.contains("varNames") && st["varNames"].is_array()) { + for (const auto &vn : st["varNames"]) { + s.varNames.push_back(vn.is_string() ? vn.get() + : vn.is_number() ? std::to_string(vn.get()) + : ""); + } + } + block.statements.push_back(std::move(s)); + } else if (stType == "varDeclAssign") { + StmtVarDeclAssign s; + s.varType = st.value("varType", ""); + s.reg = optStr(st, "reg"); + s.varName = st.value("varName", ""); + s.isConst = st.value("isConst", false); + s.line = optLine(st); + if (st.contains("calc") && !st["calc"].is_null()) { + s.calc = std::make_unique(parseCalc(st["calc"])); + } + block.statements.push_back(std::move(s)); + } else if (stType == "varUndef") { + block.statements.push_back(StmtVarUndef{ + .varName = st.value("varName", ""), + .line = optLine(st), + }); + } else if (stType == "varAssignCalc") { + StmtVarAssignCalc s; + s.varName = st.value("varName", ""); + s.swizzle = optStr(st, "swizzle"); + s.assignType = st.value("assignType", "="); + s.line = optLine(st); + if (st.contains("calc") && !st["calc"].is_null()) { + s.calc = std::make_unique(parseCalc(st["calc"])); + } + block.statements.push_back(std::move(s)); + } else if (stType == "funcCall") { + StmtFuncCall s; + s.func = st.value("func", ""); + s.line = optLine(st); + if (st.contains("args") && st["args"].is_array()) { + for (const auto &a : st["args"]) { + s.args.push_back(parseFuncArg(a)); + } + } + block.statements.push_back(std::move(s)); + } else if (stType == "labelDecl") { + block.statements.push_back(StmtLabelDecl{ + .name = st.value("name", ""), + .line = optLine(st), + }); + } else if (stType == "goto") { + block.statements.push_back(StmtGoto{ + .label = st.value("label", ""), + .line = optLine(st), + }); + } else if (stType == "if") { + StmtIf s; + s.compare = parseCompareExpr(st["compare"]); + s.line = optLine(st); + if (st.contains("blockIf") && !st["blockIf"].is_null()) { + s.blockIf = std::make_unique(parseScopedBlock(st["blockIf"])); + } + if (st.contains("blockElse") && !st["blockElse"].is_null()) { + s.blockElse = + std::make_unique(parseScopedBlock(st["blockElse"])); + } + block.statements.push_back(std::move(s)); + } else if (stType == "while") { + StmtWhile s; + s.compare = parseCompareExpr(st["compare"]); + s.line = optLine(st); + if (st.contains("block") && !st["block"].is_null()) { + s.block = std::make_unique(parseScopedBlock(st["block"])); + } + block.statements.push_back(std::move(s)); + } else if (stType == "loop") { + StmtLoop s; + s.line = optLine(st); + if (st.contains("compare") && !st["compare"].is_null()) { + s.compare = parseCompareExpr(st["compare"]); + } + if (st.contains("block") && !st["block"].is_null()) { + s.block = std::make_unique(parseScopedBlock(st["block"])); + } + block.statements.push_back(std::move(s)); + } else if (stType == "break") { + block.statements.push_back(StmtBreak{optLine(st)}); + } else if (stType == "continue") { + block.statements.push_back(StmtContinue{optLine(st)}); + } else if (stType == "exit") { + block.statements.push_back(StmtExit{optLine(st)}); + } else if (stType == "annotation") { + StmtAnnotation s; + s.name = st.value("name", ""); + s.value = st.contains("value") && !st["value"].is_null() + ? st["value"].is_string() ? st["value"].get() + : st["value"].dump() + : ""; + s.line = optLine(st); + block.statements.push_back(std::move(s)); + } else if (stType == "scopedBlock") { + // Keep as nested scoped block — scope boundaries matter for registers + StmtScopedBlock sb; + sb.line = optLine(st); + sb.body = std::make_unique(parseScopedBlock(st)); + block.statements.push_back(std::move(sb)); + } else if (stType == "nestedCalc") { + // Flatten nestedCalc (synthetic node, not a real scope) + ScopedBlock nested = parseScopedBlock(st); + for (auto &ns : nested.statements) { + block.statements.push_back(std::move(ns)); + } + } else if (stType == "varDeclAlias") { + StmtVarDeclAlias s; + s.line = optLine(st); + s.aliasName = st.value("aliasName", ""); + s.varName = st.value("varName", ""); + block.statements.push_back(std::move(s)); + } else if (stType == "nestedCalc") { + // Synthetic node from astCalcNormalizer — contains its own + // scopedBlock-like statements. Flatten them in. + ScopedBlock nested = parseScopedBlock(st); + for (auto &ns : nested.statements) { + block.statements.push_back(std::move(ns)); + } + } else { + // Unknown statement type — skip with warning + fprintf(stderr, "Warning: unknown statement type '%s', skipping\n", + stType.c_str()); + } + } + } + return block; +} + +// --- State section ---------------------------------------------------- + +static StateVarDef parseStateVarDef(const json &j) { + StateVarDef sv; + sv.varType = j.value("varType", ""); + sv.varName = j.value("varName", ""); + sv.isExtern = j.value("extern", false); + sv.align = j.value("align", int64_t{0}); + if (j.contains("arraySize") && j["arraySize"].is_array()) { + for (const auto &a : j["arraySize"]) { + sv.arraySize.push_back(a.get()); + } + } + if (j.contains("value") && j["value"].is_array()) { + for (const auto &v : j["value"]) { + sv.value.push_back(v.get()); + } + } + return sv; +} + +static StateSection parseStateSection(const json &j) { + StateSection sec; + sec.name = j.value("name", ""); + if (j.contains("vars") && j["vars"].is_array()) { + for (const auto &v : j["vars"]) { + sec.vars.push_back(parseStateVarDef(v)); + } + } + return sec; +} + +// --- Function --------------------------------------------------------- + +static Function parseFunction(const json &j) { + Function func; + if (j.contains("annotations") && j["annotations"].is_array()) { + for (const auto &a : j["annotations"]) { + func.annotations.push_back(parseAnnotation(a)); + } + } + func.type = toFuncType(j.value("type", "function")); + if (j.contains("resultType") && !j["resultType"].is_null()) { + func.resultType = j["resultType"].get(); + } + func.name = j.value("name", ""); + if (j.contains("args") && j["args"].is_array()) { + for (const auto &a : j["args"]) { + func.args.push_back(parseFuncDefArg(a)); + } + } + if (j.contains("body") && !j["body"].is_null()) { + func.body = std::make_unique(parseScopedBlock(j["body"])); + } + return func; +} + +// --- Program (top-level) ---------------------------------------------- + +Program parseJson(const std::string &jsonStr) { + json j = json::parse(jsonStr); + Program prog; + + if (j.contains("includes") && j["includes"].is_array()) { + for (const auto &inc : j["includes"]) { + prog.includes.push_back(inc.get()); + } + } + if (j.contains("states") && j["states"].is_array()) { + for (const auto &s : j["states"]) { + prog.states.push_back(parseStateSection(s)); + } + } + if (j.contains("functions") && j["functions"].is_array()) { + for (const auto &f : j["functions"]) { + prog.functions.push_back(parseFunction(f)); + } + } + if (j.contains("postIncludes") && j["postIncludes"].is_array()) { + for (const auto &inc : j["postIncludes"]) { + prog.postIncludes.push_back(inc.get()); + } + } + if (j.contains("defines") && j["defines"].is_object()) { + for (const auto &[name, def] : j["defines"].items()) { + DefineEntry entry; + entry.name = name; + entry.value = def.contains("value") ? def["value"].get() : def.get(); + prog.defines.push_back(entry); + } + } + + return prog; +} + +} // namespace rspl::ast diff --git a/cpp/src/ast.h b/cpp/src/ast.h new file mode 100644 index 0000000..91938df --- /dev/null +++ b/cpp/src/ast.h @@ -0,0 +1,311 @@ +#pragma once + +#include "types.h" + +#include +#include +#include +#include +#include +#include + +namespace rspl::ast { + +using rspl::ArgType; +using rspl::FuncType; +using rspl::TypeClass; + +// --- Leaf / argument types -------------------------------------------- + +struct ExprNum { + double value = 0.0; +}; + +struct ExprVarName { + std::string value; +}; + +// --- Function arguments ----------------------------------------------- + +struct FuncArg { + ArgType type = ArgType::Var; // var, num, or string + std::string value; + std::string swizzle; // optional, empty if absent +}; + +// --- Function definition argument ------------------------------------- + +struct FuncDefArg { + TypeClass type = TypeClass::Unknown; // data type e.g. u32, vec16 + std::string reg; // optional register constraint, e.g. "$t0" + std::string name; +}; + +// --- Annotations ------------------------------------------------------ + +struct Annotation { + std::string name; + std::string value; +}; + +// --- Comparison expression (in if/while/loop conditions) -------------- + +struct CompareExpr { + FuncArg left; + std::string op; + FuncArg right; + uint32_t line = 0; +}; + +// --- Calculation types ------------------------------------------------ + +struct CalcNum { + ExprNum right; +}; + +struct CalcVar { + std::string op; // "!" / "~" / empty + ExprVarName right; + std::string swizzleRight; +}; + +struct CalcLR { + ExprVarName left; + std::string op; + ExprNum rightNum; // filled when right is numeric + std::string rightVarName; // filled when right is a variable + std::string swizzleLeft; + std::string swizzleRight; +}; + +struct CalcMultiPart { + std::string op; + ExprVarName right; // when right is a VarName + std::optional rightVal; // when right is a number + std::string swizzleRight; + int32_t groupStart = 0; + int32_t groupEnd = 0; +}; + +struct CalcMulti { + ExprVarName left; + std::optional leftVal; // when left is a number + std::string swizzleLeft; + std::vector parts; + int32_t groupStart = 0; +}; + +struct CalcFunc { + std::string funcName; + std::vector args; + std::string swizzleRight; +}; + +struct TernaryPart { + std::string left; // variable name + std::string right; // variable name + std::optional rightVal; // value when right is a number + std::string swizzleRight; +}; + +struct CalcCompare { + std::string left; // variable name + std::string op; + std::string right; // variable name + std::optional rightVal; // value when right is a number + std::string swizzleRight; + std::optional ternary; +}; + +// Calc variant — includes all calculation node types +using Calc = std::variant< + CalcNum, + CalcVar, + CalcLR, + CalcMulti, + CalcMultiPart, + CalcFunc, + CalcCompare +>; + +// --- Statement types -------------------------------------------------- + +struct StmtVarDecl { + std::string varName; + std::string varType; + std::string reg; + bool isConst = false; + uint32_t line = 0; +}; + +struct StmtVarDeclMulti { + std::string varType; + std::string reg; + std::vector varNames; + bool isConst = false; + uint32_t line = 0; +}; + +struct StmtVarDeclAssign { + std::string varType; + std::string reg; + std::string varName; + std::unique_ptr calc; + bool isConst = false; + uint32_t line = 0; +}; + +struct StmtVarDeclAlias { + std::string aliasName; + std::string varName; + uint32_t line = 0; +}; + +struct StmtVarUndef { + std::string varName; + uint32_t line = 0; +}; + +struct StmtVarAssignCalc { + std::string varName; + std::string swizzle; + std::string assignType; // "=", "+=", "-=", etc. + std::unique_ptr calc; + uint32_t line = 0; +}; + +struct StmtFuncCall { + std::string func; + std::vector args; + uint32_t line = 0; +}; + +struct StmtLabelDecl { + std::string name; + uint32_t line = 0; +}; + +struct StmtGoto { + std::string label; + uint32_t line = 0; +}; + +struct StmtIf { + CompareExpr compare; + std::unique_ptr blockIf; + std::unique_ptr blockElse; + uint32_t line = 0; +}; + +struct StmtWhile { + CompareExpr compare; + std::unique_ptr block; + uint32_t line = 0; +}; + +struct StmtLoop { + std::optional compare; + std::unique_ptr block; + uint32_t line = 0; +}; + +struct StmtBreak { + uint32_t line = 0; +}; + +struct StmtContinue { + uint32_t line = 0; +}; + +struct StmtExit { + uint32_t line = 0; +}; + +struct StmtAnnotation { + std::string name; + std::string value; + uint32_t line = 0; +}; + +// --- Stmt variant + ScopedBlock (mutually recursive) ------------------ + +// Nested scoped block — carries its own scope +struct StmtScopedBlock { + std::unique_ptr body; + uint32_t line = 0; +}; + +using Stmt = std::variant< + StmtVarDecl, + StmtVarDeclMulti, + StmtVarDeclAssign, + StmtVarDeclAlias, + StmtVarUndef, + StmtVarAssignCalc, + StmtFuncCall, + StmtLabelDecl, + StmtGoto, + StmtIf, + StmtWhile, + StmtLoop, + StmtBreak, + StmtContinue, + StmtExit, + StmtAnnotation, + StmtScopedBlock +>; + +struct ScopedBlock { + std::vector statements; + uint32_t line = 0; +}; + +// --- State variable definition ---------------------------------------- + +struct StateVarDef { + std::string varType; + std::string varName; + bool isExtern = false; + std::vector arraySize; + int64_t align = 0; + std::vector value; +}; + +// --- State section ---------------------------------------------------- + +struct StateSection { + std::string name; // "state", "data", "bss" + std::vector vars; +}; + +// --- Function --------------------------------------------------------- + +struct Function { + std::vector annotations; + FuncType type = FuncType::Function; // function, command, macro + std::optional resultType; // command index + std::string name; + std::vector args; + std::unique_ptr body; // null for extern declarations +}; + +// --- Top-level AST ---------------------------------------------------- + +struct DefineEntry { + std::string name; + std::string value; +}; + +struct Program { + std::vector includes; + std::vector states; + std::vector functions; + std::vector postIncludes; + std::vector defines; +}; + +// --- JSON deserialization --------------------------------------------- + +Program parseJson(const std::string &json); + +} // namespace rspl::ast diff --git a/cpp/src/ast2asm.cpp b/cpp/src/ast2asm.cpp new file mode 100644 index 0000000..bd4db78 --- /dev/null +++ b/cpp/src/ast2asm.cpp @@ -0,0 +1,1078 @@ +#include "ast2asm.h" + +#include "asm.h" +#include "asm_normalize.h" +#include "ast.h" +#include "astCalcNormalizer.h" +#include "builtins.h" +#include "operations/branch.h" +#include "operations/scalar.h" +#include "operations/user_function.h" +#include "operations/vector.h" +#include "registers.h" +#include "state.h" +#include "swizzle.h" +#include "types.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace rspl { + +// --- Macro registry --------------------------------------------------- + +static std::unordered_map macros; + +// --- Forward declarations --------------------------------------------- + +static std::vector +scopedBlockToAsm(const ast::ScopedBlock &block); + +// --- Macro inlining --------------------------------------------------- + +static std::vector +inlineMacroCall(const std::string ¯oName, + const std::vector &args) { + auto it = macros.find(macroName); + if (it == macros.end()) return {}; + + const ast::Function ¯o = *it->second; + if (macro.args.size() != args.size()) { + state.throwError("Macro '" + macroName + "' expects " + + std::to_string(macro.args.size()) + + " arguments, got " + std::to_string(args.size()) + + "!"); + } + + std::vector res; + state.pushScope("", ""); + for (size_t i = 0; i < args.size(); ++i) { + state.declareVarAlias(macro.args[i].name, args[i].value); + } + + auto body = scopedBlockToAsm(*macro.body); + res.insert(res.end(), body.begin(), body.end()); + state.popScope(); + return res; +} + +static const std::string LABEL_CMD_LOOP = "RSPQ_Loop"; + +// --- Type inference for declarations ---------------------------------- + +static TypeClass inferCalcResultType(const ast::Calc &calc, + const std::string &declType) { + return std::visit( + [&](const auto &c) -> TypeClass { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return toTypeClass(declType); + } else if constexpr (std::is_same_v) { + return toTypeClass(declType); + } else if constexpr (std::is_same_v) { + const VarDef *l = state.getVar(c.left.value); + if (!c.rightVarName.empty()) { + const VarDef *r = state.getVar(c.rightVarName); + if (l && r && (isVecType(l->type) || isVecType(r->type))) + return isVecType(l->type) ? l->type : r->type; + } + if (l && isVecType(l->type)) return l->type; + return toTypeClass(declType); + } else if constexpr (std::is_same_v) { + // If the declared type is already a vector type, trust it. + // This matters for mixed-type expressions like vec16 * vec32 + // whose result should be vec32 if the variable is declared as + // vec32. Without this the left operand's type (vec16) wins and + // later type-sensitive operations (e.g. clip()) receive wrong + // types. + if (isVecType(declType)) return toTypeClass(declType); + const VarDef *l = state.getVar(c.left.value); + if (l && isVecType(l->type)) return l->type; + for (const auto &p : c.parts) { + if (!p.right.value.empty()) { + const VarDef *r = state.getVar(p.right.value); + if (r && isVecType(r->type)) return r->type; + } + } + return toTypeClass(declType); + } else { + return toTypeClass(declType); + } + }, + calc); +} + +// --- Forward declaration ---------------------------------------------- + +static std::vector +calcToAsm(const ast::Calc &calc, const VarDef &varRes); + +// --- Group decomposition for CalcMulti --------------------------------- +// Port of JS astCalcNormalizer + astCalcPartsToASM. +// Flattens group markers, applies precedence, constant-folds (via +// partsEval), and decomposes complex expressions into temp variables. + +// Flatten calcMulti parts + group markers into a linear token vector. +static std::vector flattenCalcMulti(const ast::CalcMulti &cm) { + std::vector out; + // Use a special OP-like sentinel for brackets so partsToTree can find them. + FlatElem lparen, rparen; + lparen.kind = FlatElem::OP; + lparen.opStr = "("; + rparen.kind = FlatElem::OP; + rparen.opStr = ")"; + + auto addBrackets = [&](int count, const FlatElem &b) { + for (int i = 0; i < count; ++i) out.push_back(b); + }; + addBrackets(cm.groupStart, lparen); + if (cm.leftVal.has_value()) { + out.push_back({FlatElem::VAL, {}, cm.leftVal.value(), {}, + cm.swizzleLeft, true}); + } else { + out.push_back({FlatElem::VAL, {}, 0, cm.left.value, + cm.swizzleLeft, false}); + } + for (const auto &p : cm.parts) { + out.push_back({FlatElem::OP, p.op}); + addBrackets(p.groupStart, lparen); + if (p.rightVal.has_value()) { + out.push_back({FlatElem::VAL, {}, p.rightVal.value(), {}, + p.swizzleRight, true}); + } else { + out.push_back({FlatElem::VAL, {}, 0, p.right.value, + p.swizzleRight, false}); + } + addBrackets(p.groupEnd, rparen); + } + return out; +} + +// Convert bracket markers "(" / ")" into nested FlatElem vectors. +static void partsToTree(std::vector &parts) { + for (size_t i = 0; i < parts.size();) { + if (parts[i].kind == FlatElem::OP && parts[i].opStr == "(") { + int depth = 1; + size_t start = i; + size_t j = i + 1; + for (; j < parts.size() && depth > 0; ++j) { + if (parts[j].kind == FlatElem::OP && parts[j].opStr == "(") depth++; + else if (parts[j].kind == FlatElem::OP && parts[j].opStr == ")") depth--; + } + // Extract sub-expression between brackets + std::vector sub(parts.begin() + start + 1, + parts.begin() + j - 1); + partsToTree(sub); // recurse into sub-expression + // Replace the bracket group with a single nested FlatElem + parts.erase(parts.begin() + start, parts.begin() + j); + FlatElem nested; + nested.kind = FlatElem::VAL; + nested.varName = NESTED_SENTINEL; + nested.nested = std::move(sub); + nested.isNested = true; + parts.insert(parts.begin() + start, std::move(nested)); + i = start + 1; // continue after the inserted element + } else { + ++i; + } + } +} + +// Forward declaration for mutual recursion. +static void decomposeParts(std::vector &parts, + const VarDef &varRes, + std::vector &out, + int &tmpCounter); + +// Resolve a FlatElem value into a VarDef. For nested sub-expressions, +// recursively decompose into temp variables and return the temp var. +static VarDef resolveFlatVal(FlatElem &elem, + const VarDef &varRes, + std::vector &out, + int &tmpCounter) { + if (elem.isNested) { + // Recursively decompose the nested sub-expression into a temp variable + std::string tmpName = "__tmp_" + std::to_string(tmpCounter++); + state.declareVar(tmpName, toString(varRes.type), + state.allocRegister(toString(varRes.type))); + VarDef tmpVar = state.getRequiredVarCopy(tmpName, "tmp"); + decomposeParts(elem.nested, tmpVar, out, tmpCounter); + return tmpVar; + } + if (elem.isNum) { + VarDef v; + v.value = elem.numVal; + v.type = varRes.type; + return v; + } + VarDef v = state.getRequiredVarCopy(elem.varName, "val"); + v.swizzle = elem.swizzle; + return v; +} + +// Decompose a parts vector into ASM instructions, accumulating into +// `varRes`. Nested sub-expressions are emitted into temp variables. +static void decomposeParts(std::vector &parts, + const VarDef &varRes, + std::vector &out, + int &tmpCounter) { + if (parts.empty()) return; + + // Resolve first value + size_t pos = 0; + VarDef accVar; + bool accIsConst = false; + double accConst = 0; + + if (pos < parts.size() && parts[pos].kind == FlatElem::VAL) { + if (parts[pos].isNested) { + accVar = resolveFlatVal(parts[pos], varRes, out, tmpCounter); + } else if (parts[pos].isNum) { + accIsConst = true; + accConst = parts[pos].numVal; + } else { + accVar = state.getRequiredVarCopy(parts[pos].varName, "left"); + accVar.swizzle = parts[pos].swizzle; + } + ++pos; + } + + if (pos >= parts.size()) { + VarDef finalLeft; + if (accIsConst) { + finalLeft.value = accConst; + finalLeft.type = varRes.type; + } else { + finalLeft = accVar; + } + auto mv = isVecType(varRes.type) ? ops::opMoveVec(varRes, finalLeft) + : ops::opMove(varRes, finalLeft); + out.insert(out.end(), mv.begin(), mv.end()); + return; + } + + bool isFirst = true; + VarDef firstLeft; + if (!accIsConst) firstLeft = accVar; + + while (pos + 1 <= parts.size() && + (pos < parts.size() && parts[pos].kind == FlatElem::OP)) { + std::string op = parts[pos].opStr; + // Skip bracket sentinels (shouldn't appear after partsToTree) + if (op == "(" || op == ")") { ++pos; continue; } + ++pos; + VarDef right; + if (pos < parts.size() && parts[pos].kind == FlatElem::VAL) { + right = resolveFlatVal(parts[pos], varRes, out, tmpCounter); + ++pos; + } + + if (isFirst) { + isFirst = false; + if (accIsConst) { + VarDef cl; + cl.value = accConst; + cl.type = varRes.type; + auto mv = isVecType(varRes.type) ? ops::opMoveVec(varRes, cl) + : ops::opMove(varRes, cl); + out.insert(out.end(), mv.begin(), mv.end()); + accIsConst = false; + // Apply op to varRes + if (!isVecType(varRes.type)) { + if (op == "+") { auto a = ops::opAdd(varRes, varRes, right); out.insert(out.end(), a.begin(), a.end()); } + else if (op == "-") { auto s = ops::opSub(varRes, varRes, right); out.insert(out.end(), s.begin(), s.end()); } + else if (op == "*") { auto m = ops::opMul(varRes, varRes, right); out.insert(out.end(), m.begin(), m.end()); } + } + } else { + // Try to fuse move + first op by calling the appropriate op directly + if (!isVecType(varRes.type)) { + if (op == "+") { + auto a = ops::opAdd(varRes, firstLeft, right); + out.insert(out.end(), a.begin(), a.end()); + } else if (op == "-") { + auto s = ops::opSub(varRes, firstLeft, right); + out.insert(out.end(), s.begin(), s.end()); + } else if (op == "*") { + auto m = ops::opMul(varRes, firstLeft, right); + out.insert(out.end(), m.begin(), m.end()); + } else if (op == "/") { + auto d = ops::opDiv(varRes, firstLeft, right); + out.insert(out.end(), d.begin(), d.end()); + } else if (op == "&") { + auto a = ops::opAnd(varRes, firstLeft, right); + out.insert(out.end(), a.begin(), a.end()); + } else if (op == "|") { + auto o = ops::opOr(varRes, firstLeft, right); + out.insert(out.end(), o.begin(), o.end()); + } else if (op == "^") { + auto x = ops::opXOR(varRes, firstLeft, right); + out.insert(out.end(), x.begin(), x.end()); + } else if (op == "<<") { + auto s = ops::opShiftLeft(varRes, firstLeft, right); + out.insert(out.end(), s.begin(), s.end()); + } else if (op == ">>") { + auto s = ops::opShiftRight(varRes, firstLeft, right, false); + out.insert(out.end(), s.begin(), s.end()); + } else if (op == ">>>") { + auto s = ops::opShiftRight(varRes, firstLeft, right, true); + out.insert(out.end(), s.begin(), s.end()); + } else { + // Unknown op: move then apply + auto mv = ops::opMove(varRes, firstLeft); + out.insert(out.end(), mv.begin(), mv.end()); + } + } else { + // Vec ops + if (op == "+") { + auto a = ops::opAddVec(varRes, firstLeft, right); + out.insert(out.end(), a.begin(), a.end()); + } else if (op == "-") { + auto s = ops::opSubVec(varRes, firstLeft, right); + out.insert(out.end(), s.begin(), s.end()); + } else if (op == "*") { + auto m = ops::opMulVec(varRes, firstLeft, right, true); + out.insert(out.end(), m.begin(), m.end()); + } else if (op == "+*") { + auto m = ops::opMulVec(varRes, firstLeft, right, false); + out.insert(out.end(), m.begin(), m.end()); + } else { + auto mv = ops::opMoveVec(varRes, firstLeft); + out.insert(out.end(), mv.begin(), mv.end()); + } + } + } + } else { + if (!isVecType(varRes.type)) { + if (op == "+") { auto a = ops::opAdd(varRes, varRes, right); out.insert(out.end(), a.begin(), a.end()); } + else if (op == "-") { auto s = ops::opSub(varRes, varRes, right); out.insert(out.end(), s.begin(), s.end()); } + else if (op == "*") { auto m = ops::opMul(varRes, varRes, right); out.insert(out.end(), m.begin(), m.end()); } + else if (op == "/") { auto d = ops::opDiv(varRes, varRes, right); out.insert(out.end(), d.begin(), d.end()); } + else if (op == "&") { auto a = ops::opAnd(varRes, varRes, right); out.insert(out.end(), a.begin(), a.end()); } + else if (op == "|") { auto o = ops::opOr(varRes, varRes, right); out.insert(out.end(), o.begin(), o.end()); } + else if (op == "^") { auto x = ops::opXOR(varRes, varRes, right); out.insert(out.end(), x.begin(), x.end()); } + else if (op == "<<") { auto s = ops::opShiftLeft(varRes, varRes, right); out.insert(out.end(), s.begin(), s.end()); } + else if (op == ">>") { auto s = ops::opShiftRight(varRes, varRes, right, false); out.insert(out.end(), s.begin(), s.end()); } + else if (op == ">>>") { auto s = ops::opShiftRight(varRes, varRes, right, true); out.insert(out.end(), s.begin(), s.end()); } + } + } + accIsConst = false; + accVar = varRes; + } +} + +static std::vector +decomposeCalcMulti(const ast::CalcMulti &cm, const VarDef &varRes) { + // Fast path: single part, no groups, variable left + if (cm.parts.size() == 1 && cm.groupStart == 0 && + cm.parts[0].groupStart == 0 && cm.parts[0].groupEnd == 0 && + !cm.leftVal.has_value()) { + ast::CalcLR lrCalc; + lrCalc.left = cm.left; + lrCalc.swizzleLeft = cm.swizzleLeft; + lrCalc.op = cm.parts[0].op; + if (cm.parts[0].rightVal.has_value()) { + lrCalc.rightNum = ast::ExprNum{cm.parts[0].rightVal.value()}; + } else { + lrCalc.rightVarName = cm.parts[0].right.value; + } + lrCalc.swizzleRight = cm.parts[0].swizzleRight; + return calcToAsm(ast::Calc(lrCalc), varRes); + } + + // Step 1: flatten group markers into bracket tokens + auto parts = flattenCalcMulti(cm); + + // Step 2: convert brackets to nested structure + partsToTree(parts); + + // Step 2.5: apply operator precedence within nested groups + for (auto &e : parts) { + if (e.isNested) applyPrecedence(e.nested); + } + + // Step 3: evaluate constant sub-expressions (delegated to partsEval) + auto evalResult = partsEval(parts); + if (std::holds_alternative(evalResult)) { + // Entire expression folded to a single constant + FlatElem &elem = std::get(evalResult); + VarDef v; + v.value = elem.numVal; + v.type = varRes.type; + return isVecType(varRes.type) ? ops::opMoveVec(varRes, v) + : ops::opMove(varRes, v); + } + // partsEval returned the (possibly modified) parts vector + parts = std::get>(std::move(evalResult)); + + // Step 4: decompose into temp variables + std::vector res; + int tmpCounter = 0; + decomposeParts(parts, varRes, res, tmpCounter); + return res; +} + +// --- Calculation to ASM ----------------------------------------------- + +static std::vector +calcToAsm(const ast::Calc &calc, const VarDef &varRes) { + return std::visit( + [&](const auto &c) -> std::vector { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + VarDef vRight; + vRight.value = c.right.value; + if (isVecType(varRes.type)) { + return ops::opMoveVec(varRes, vRight); + } + return ops::opMove(varRes, vRight); + } + + else if constexpr (std::is_same_v) { + // Check if the variable is actually a label / state memory + // (JS: astNormalize.js lines 161-166) + const auto *memVar = state.getMemVarOrNull(c.right.value); + if (memVar) { + // Convert label reference to %lo(NAME) immediate + VarDef vRight; + vRight.value = 0; + vRight.type = varRes.type; + vRight.reg = "%lo(" + c.right.value + ")"; + if (isVecType(varRes.type)) + return ops::opMoveVec(varRes, vRight); + return ops::opMove(varRes, vRight); + } + VarDef vRight = + state.getRequiredVarCopy(c.right.value, "right"); + vRight.swizzle = c.swizzleRight; + if (c.op == "~") { + if (isVecType(varRes.type)) + return ops::opBitFlipVec(varRes, vRight); + return ops::opBitFlip(varRes, vRight); + } + if (isVecType(varRes.type)) { + return ops::opMoveVec(varRes, vRight); + } + return ops::opMove(varRes, vRight); + } + + else if constexpr (std::is_same_v) { + VarDef vLeft = + state.getRequiredVarCopy(c.left.value, "Left"); + vLeft.swizzle = c.swizzleLeft; + + VarDef vRight; + if (!c.rightVarName.empty()) { + // Check for label / state-memory reference (JS: astNormalize.js:161-166) + const auto *memVar = + state.getMemVarOrNull(c.rightVarName); + if (memVar) { + vRight.value = 0; + vRight.type = varRes.type; + vRight.reg = "%lo(" + c.rightVarName + ")"; + } else { + vRight = state.getRequiredVarCopy(c.rightVarName, "right"); + } + } else { + vRight.value = c.rightNum.value; + vRight.type = varRes.type; + } + vRight.swizzle = c.swizzleRight; + + bool isVec = isVecType(varRes.type); + std::string op = c.op; + + if (!isVec) { + if (!c.swizzleLeft.empty() && !vLeft.reg.empty() && + !isVecType(vLeft.type)) + state.throwError( + "Swizzling not allowed for scalar operations!"); + if (!c.swizzleRight.empty() && !vRight.reg.empty() && + !isVecType(vRight.type)) + state.throwError( + "Swizzling not allowed for scalar operations!"); + } + + if (isVec) { + if (op == "+") return ops::opAddVec(varRes, vLeft, vRight); + if (op == "-") return ops::opSubVec(varRes, vLeft, vRight); + if (op == "*" || op == "+*") + return ops::opMulVec(varRes, vLeft, vRight, op == "*"); + if (op == "&") return ops::opAndVec(varRes, vLeft, vRight); + if (op == "|") return ops::opOrVec(varRes, vLeft, vRight); + if (op == "^") return ops::opXORVec(varRes, vLeft, vRight); + if (op == "<<") + return ops::opShiftLeftVec(varRes, vLeft, vRight); + if (op == ">>") + return ops::opShiftRightVec(varRes, vLeft, vRight, false); + if (op == ">>>") + return ops::opShiftRightVec(varRes, vLeft, vRight, true); + } else { + if (op == "+") return ops::opAdd(varRes, vLeft, vRight); + if (op == "-") return ops::opSub(varRes, vLeft, vRight); + if (op == "*") return ops::opMul(varRes, vLeft, vRight); + if (op == "/") return ops::opDiv(varRes, vLeft, vRight); + if (op == "&") return ops::opAnd(varRes, vLeft, vRight); + if (op == "|") return ops::opOr(varRes, vLeft, vRight); + if (op == "^") return ops::opXOR(varRes, vLeft, vRight); + if (op == "~|") return ops::opNOR(varRes, vLeft, vRight); + if (op == "<<") + return ops::opShiftLeft(varRes, vLeft, vRight); + if (op == ">>") + return ops::opShiftRight(varRes, vLeft, vRight, false); + if (op == ">>>") + return ops::opShiftRight(varRes, vLeft, vRight, true); + } + state.throwError("Unknown operator: " + op); + return {}; + } + + else if constexpr (std::is_same_v) { + return decomposeCalcMulti(c, varRes); + } + + else if constexpr (std::is_same_v) { + if (macros.count(c.funcName)) { + std::vector callArgs; + callArgs.push_back( + {.type = ArgType::Var, .value = varRes.name, .swizzle = ""}); + for (auto &a : c.args) callArgs.push_back(a); + return inlineMacroCall(c.funcName, callArgs); + } + auto *bf = builtins::lookup(c.funcName); + if (!bf) + state.throwError("Unknown builtin: " + c.funcName); + VarDef resCopy = varRes; + return (*bf)(&resCopy, c.args, c.swizzleRight); + } + + else if constexpr (std::is_same_v) { + bool isVec = isVecType(varRes.type); + if (isVec) { + VarDef vLeft = + state.getRequiredVarCopy(c.left, "left"); + VarDef vRight; + if (c.rightVal.has_value()) { + auto pIt = + POW2_SWIZZLE_VAR.find(c.rightVal.value()); + if (pIt == POW2_SWIZZLE_VAR.end()) + state.throwError("Constant must be a power of two!"); + vRight.reg = pIt->second.reg; + vRight.swizzle = pIt->second.swizzle; + vRight.type = TypeClass::Vec16; + } else { + vRight = state.getRequiredVarCopy(c.right, "right"); + vRight.swizzle = c.swizzleRight; + } + const ast::TernaryPart *tp = + c.ternary.has_value() ? &c.ternary.value() : nullptr; + return ops::opCompareVec(varRes, vLeft, vRight, c.op, tp); + } else { + VarDef vLeft = + state.getRequiredVarCopy(c.left, "left"); + VarDef vRight; + if (c.rightVal.has_value()) { + vRight.value = c.rightVal.value(); + vRight.type = varRes.type; + } else { + vRight = state.getRequiredVarCopy(c.right, "right"); + } + return ops::opCompare(varRes, vLeft, vRight, c.op, false); + } + } + + state.throwError("Unknown calculation type"); + return {}; + }, + calc); +} + +// --- Control flow ----------------------------------------------------- + +static std::vector ifToAsm(const ast::StmtIf &st) { + const VarDef *varLeft = + state.getRequiredVar(st.compare.left.value, "left"); + if (reg::isVecReg(varLeft->reg)) + state.throwError("IF-Statements must use scalar-registers!"); + + std::string labelElse = state.generateLabel(); + std::string labelEnd = + st.blockElse ? state.generateLabel() : labelElse; + + std::vector res; + auto branch = ops::opBranch(st.compare, labelElse); + res.insert(res.end(), branch.begin(), branch.end()); + + state.pushScope("", ""); + auto ifBlock = scopedBlockToAsm(*st.blockIf); + res.insert(res.end(), ifBlock.begin(), ifBlock.end()); + if (st.blockElse) { + res.push_back( + asmBranch("beq", {"$zero", "$zero", labelEnd}, labelEnd)); + res.push_back(asmNOP()); + } + state.popScope(); + + if (st.blockElse) { + state.pushScope("", labelElse); + res.push_back(asmLabel(labelElse)); + auto elseBlock = scopedBlockToAsm(*st.blockElse); + res.insert(res.end(), elseBlock.begin(), elseBlock.end()); + state.popScope(); + } + res.push_back(asmLabel(labelEnd)); + return res; +} + +static std::vector whileToAsm(const ast::StmtWhile &st) { + const VarDef *varLeft = + state.getRequiredVar(st.compare.left.value, "left"); + if (reg::isVecReg(varLeft->reg)) + state.throwError("While-Statements must use scalar-registers!"); + + std::string labelStart = state.generateLabel(); + std::string labelEnd = state.generateLabel(); + + std::vector res; + res.push_back(asmLabel(labelStart)); + + auto branch = ops::opBranch(st.compare, labelEnd); + res.insert(res.end(), branch.begin(), branch.end()); + + state.pushScope(labelStart, labelEnd); + auto body = scopedBlockToAsm(*st.block); + res.insert(res.end(), body.begin(), body.end()); + state.popScope(); + + res.push_back(asmOp("j", {labelStart})); + res.push_back(asmNOP()); + res.push_back(asmLabel(labelEnd)); + return res; +} + +static std::vector loopToAsm(const ast::StmtLoop &st) { + std::string labelStart = state.generateLabel(); + std::string labelEnd = state.generateLabel(); + + // loop { body } while(cond) — emit conditional branch at the tail + if (st.compare.has_value()) { + if (st.compare->left.type == ArgType::Num) { + state.throwError( + "Loop-Statements with numeric left-hand-side not implemented!"); + } + const VarDef *varLeft = + state.getRequiredVar(st.compare->left.value, "left"); + if (reg::isVecReg(varLeft->reg)) + state.throwError("Loop-Statements must use scalar-registers!"); + + std::vector res; + res.push_back(asmLabel(labelStart)); + state.pushScope(labelStart, labelEnd); + auto body = scopedBlockToAsm(*st.block); + res.insert(res.end(), body.begin(), body.end()); + auto branchOps = + ops::opBranch(*st.compare, labelStart, /*invert=*/true); + res.insert(res.end(), branchOps.begin(), branchOps.end()); + state.popScope(); + res.push_back(asmLabel(labelEnd)); + return res; + } + + // Infinite loop: j back to start + std::vector res; + res.push_back(asmLabel(labelStart)); + state.pushScope(labelStart, labelEnd); + auto body = scopedBlockToAsm(*st.block); + res.insert(res.end(), body.begin(), body.end()); + state.popScope(); + res.push_back(asmOp("j", {labelStart})); + res.push_back(asmNOP()); + res.push_back(asmLabel(labelEnd)); + return res; +} + +// --- Statement dispatch ------------------------------------------------ + +// Pre-scan scoped blocks for label declarations and register them as +// memory variables. Ported from JS astNormalize.js lines 19-28. +static void predeclareLabels(const ast::ScopedBlock &block) { + for (const auto &stmt : block.statements) { + if (std::holds_alternative(stmt)) { + auto &ld = std::get(stmt); + state.declareMemVar(ld.name, "u16", 1); + } else if (auto *sb = std::get_if(&stmt)) { + predeclareLabels(*sb->body); + } else if (auto *si = std::get_if(&stmt)) { + if (si->blockIf) predeclareLabels(*si->blockIf); + if (si->blockElse) predeclareLabels(*si->blockElse); + } else if (auto *sw = std::get_if(&stmt)) { + if (sw->block) predeclareLabels(*sw->block); + } else if (auto *sl = std::get_if(&stmt)) { + if (sl->block) predeclareLabels(*sl->block); + } + } +} + +static std::vector +scopedBlockToAsm(const ast::ScopedBlock &block) { + state.line = block.line; + // Pre-scan labels so forward references resolve (JS: astNormalize.js:19-28) + predeclareLabels(block); + + std::vector res; + + for (const auto &stmt : block.statements) { + // Update state.line from the statement's line number (for debug info) + std::visit([&](const auto &s) { state.line = s.line; }, stmt); + + std::visit( + [&](const auto &s) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + std::string reg = s.reg.empty() + ? state.allocRegister(s.varType) + : s.reg; + state.declareVar(s.varName, s.varType, reg, + s.isConst); + } + + else if constexpr (std::is_same_v) { + for (size_t i = 0; i < s.varNames.size(); ++i) { + int step = isTwoRegType(s.varType) ? 2 : 1; + int offset = static_cast(i) * step; + std::string reg = s.reg.empty() + ? state.allocRegister(s.varType) + : reg::nextReg(s.reg, offset) + ? *reg::nextReg(s.reg, offset) + : s.reg; + state.declareVar(s.varNames[i], s.varType, reg, + s.isConst); + } + } + + else if constexpr (std::is_same_v) { + std::string baseName = + s.varName.substr(0, s.varName.find(':')); + std::string effectiveType = s.varType; + if (s.calc) { + effectiveType = toString(inferCalcResultType(*s.calc, s.varType)); + } + state.declareVar(baseName, effectiveType, + s.reg.empty() + ? state.allocRegister(effectiveType) + : s.reg, + s.isConst); + if (s.calc) { + VarDef vr = state.getRequiredVarCopy( + s.varName, "result"); + auto calcAsm = calcToAsm(*s.calc, vr); + res.insert(res.end(), calcAsm.begin(), + calcAsm.end()); + state.markVarModified(baseName); + } + } + + else if constexpr (std::is_same_v) { + state.declareVarAlias(s.aliasName, s.varName); + } + + else if constexpr (std::is_same_v) { + state.undefVar(s.varName); + } + + else if constexpr (std::is_same_v) { + bool handledAsFuncCall = false; + if (auto *cf = std::get_if(s.calc.get())) { + if (!builtins::lookup(cf->funcName)) { + std::vector callArgs; + callArgs.push_back( + {.type = ArgType::Var, .value = s.varName, .swizzle = s.swizzle}); + for (auto &a : cf->args) callArgs.push_back(a); + if (macros.count(cf->funcName)) { + auto inlineRes = inlineMacroCall(cf->funcName, callArgs); + res.insert(res.end(), inlineRes.begin(), inlineRes.end()); + } else { + auto callRes = ops::callUserFunction(cf->funcName, callArgs); + res.insert(res.end(), callRes.begin(), callRes.end()); + } + handledAsFuncCall = true; + } + } + if (!handledAsFuncCall) { + VarDef vr = + state.getRequiredVarCopy(s.varName, "result"); + vr.swizzle = s.swizzle; + + if (vr.isConst && vr.modifyCount > 0) { + state.throwError("Cannot assign to constant variable!"); + } + state.markVarModified(s.varName); + + std::string op = s.assignType; + if (op.size() >= 1 && op != "=") { + std::string baseOp = op.substr(0, op.size() - 1); + if (auto *cn = std::get_if(s.calc.get())) { + ast::CalcLR lrCalc; + lrCalc.left = ast::ExprVarName{s.varName}; + lrCalc.op = baseOp; + lrCalc.rightNum = cn->right; + auto calcAsm = calcToAsm(ast::Calc(lrCalc), vr); + res.insert(res.end(), calcAsm.begin(), calcAsm.end()); + } else if (auto *cv = std::get_if(s.calc.get())) { + ast::CalcLR lrCalc; + lrCalc.left = ast::ExprVarName{s.varName}; + lrCalc.swizzleLeft = s.swizzle; + lrCalc.op = baseOp; + lrCalc.rightVarName = cv->right.value; + lrCalc.swizzleRight = cv->swizzleRight; + auto calcAsm = calcToAsm(ast::Calc(lrCalc), vr); + res.insert(res.end(), calcAsm.begin(), calcAsm.end()); + } else if (auto *cm = std::get_if(s.calc.get())) { + // Wrap compound assign: a += expr -> a = a + (expr) + // Structure: left=a, part0 opens group for expr. + ast::CalcMulti wrapped; + wrapped.left = ast::ExprVarName{s.varName}; + wrapped.swizzleLeft = s.swizzle; + // First part: a + (expr) - open a bracket for expr + ast::CalcMultiPart firstPart; + firstPart.op = baseOp; + if (cm->leftVal.has_value()) { + firstPart.rightVal = cm->leftVal; + } else { + firstPart.right = cm->left; + } + firstPart.swizzleRight = cm->swizzleLeft; + firstPart.groupStart = 1 + cm->groupStart; // open sub-expr + firstPart.groupEnd = 0; + wrapped.parts.push_back(std::move(firstPart)); + // Copy original parts, closing the bracket at the end + for (auto &p : cm->parts) { + wrapped.parts.push_back(p); + } + wrapped.parts.back().groupEnd += 1; // close sub-expr + auto calcAsm = calcToAsm(ast::Calc(wrapped), vr); + res.insert(res.end(), calcAsm.begin(), calcAsm.end()); + } else { + auto calcAsm = calcToAsm(*s.calc, vr); + res.insert(res.end(), calcAsm.begin(), calcAsm.end()); + } + } else { + auto calcAsm = calcToAsm(*s.calc, vr); + res.insert(res.end(), calcAsm.begin(), calcAsm.end()); + } + } + } + + else if constexpr (std::is_same_v) { + if (macros.count(s.func)) { + auto inlineRes = inlineMacroCall(s.func, s.args); + res.insert(res.end(), inlineRes.begin(), inlineRes.end()); + } else { + auto *bf = builtins::lookup(s.func); + if (bf) { + auto callRes = (*bf)(nullptr, s.args, ""); + res.insert(res.end(), callRes.begin(), callRes.end()); + } else { + auto callRes = ops::callUserFunction(s.func, s.args); + res.insert(res.end(), callRes.begin(), callRes.end()); + } + } + } + + else if constexpr (std::is_same_v) { + res.push_back(asmLabel(s.name)); + } + + else if constexpr (std::is_same_v) { + res.push_back(asmOp("j", {s.label})); + res.push_back(asmNOP()); + } + + else if constexpr (std::is_same_v) { + auto ifRes = ifToAsm(s); + res.insert(res.end(), ifRes.begin(), ifRes.end()); + } + + else if constexpr (std::is_same_v) { + auto wRes = whileToAsm(s); + res.insert(res.end(), wRes.begin(), wRes.end()); + } + + else if constexpr (std::is_same_v) { + auto lRes = loopToAsm(s); + res.insert(res.end(), lRes.begin(), lRes.end()); + } + + else if constexpr (std::is_same_v) { + const Scope &scope = state.getScope(); + if (!scope.labelEnd.empty()) { + res.push_back(asmOp("j", {scope.labelEnd})); + res.push_back(asmNOP()); + } + } + + else if constexpr (std::is_same_v) { + const Scope &scope = state.getScope(); + if (!scope.labelStart.empty()) { + res.push_back(asmOp("j", {scope.labelStart})); + res.push_back(asmNOP()); + } + } + + else if constexpr (std::is_same_v) { + res.push_back(asmOp("j", {LABEL_CMD_LOOP})); + res.push_back(asmNOP()); + } + + else if constexpr (std::is_same_v) { + state.addAnnotation(s.name, s.value); + } + + else if constexpr (std::is_same_v) { + state.pushScope("", ""); + auto body = scopedBlockToAsm(*s.body); + res.insert(res.end(), body.begin(), body.end()); + state.popScope(); + } + }, + stmt); + + // Clear per-statement annotations (matching JS ast2asm.js:394-395) + // Annotation statements themselves are exempt so their annotations + // apply to the next real statement. + if (!std::holds_alternative(stmt)) + state.clearAnnotations(); + } + + return res; +} + +// --- ast2asm main ----------------------------------------------------- + +std::vector ast2asm(const ast::Program &ast) { + std::vector result; + state.reset(); + + // Register macros + macros.clear(); + for (const auto &fn : ast.functions) { + if (fn.type == FuncType::Macro) { + macros[fn.name] = &fn; + } + } + + // Pre-declare memory variables from state/data/bss sections + for (const auto &sec : ast.states) { + for (const auto &sv : sec.vars) { + int64_t arraySize = 1; + for (auto dim : sv.arraySize) + arraySize *= dim; + if (arraySize < 1) arraySize = 1; + state.declareMemVar(sv.varName, sv.varType, + static_cast(arraySize)); + } + } + + // Pre-declare all functions so they can reference each other + for (const auto &fn : ast.functions) { + if (fn.type == FuncType::Function || fn.type == FuncType::Command) { + bool isRelative = false; + for (const auto &ann : fn.annotations) { + if (ann.name == "Relative") isRelative = true; + } + state.declareFunction(fn.name, fn.args, isRelative); + } + } + + for (const auto &fn : ast.functions) { + if (fn.type == FuncType::Macro) continue; // already registered + if (!fn.body) continue; // forward declaration only — no body to generate + + // argSize in bytes, matching JS getArgSize() = max(args.length * 4, 4) + int byteArgSize = + std::max(static_cast(fn.args.size()) * 4, 4); + state.enterFunction(fn.name, toString(fn.type), + fn.type == FuncType::Command ? byteArgSize + : fn.resultType.value_or(0)); + + bool isCommand = (fn.type == FuncType::Command); + // Built-in registers (ZERO, VZERO, RA, etc.) are already + // declared by enterFunction(). + + // Declare function arguments + int argSize = 0; + static const char *argRegs[] = {reg::Reg::A0, reg::Reg::A1, + reg::Reg::A2, reg::Reg::A3}; + for (const auto &arg : fn.args) { + std::string reg; + if (!arg.reg.empty()) { + reg = arg.reg; + } else if (argSize < 4) { + reg = argRegs[argSize]; + } else { + reg = state.allocRegister(toString(arg.type)); + } + state.declareVar(arg.name, toString(arg.type), reg); + argSize++; + } + + std::vector funcAsm; + auto body = scopedBlockToAsm(*fn.body); + funcAsm.insert(funcAsm.end(), body.begin(), body.end()); + + // Advance past the closing brace (matching JS ast2asm.js:443) + ++state.line; + + // Check @NoReturn annotation (matching JS ast2asm.js:445) + bool needsReturn = true; + for (const auto &ann : fn.annotations) { + if (ann.name == "NoReturn") needsReturn = false; + } + if (needsReturn) { + if (isCommand) { + funcAsm.push_back(asmOp("j", {LABEL_CMD_LOOP})); + funcAsm.push_back(asmNOP()); + } else { + funcAsm.push_back(asmOp("jr", {reg::Reg::RA})); + funcAsm.push_back(asmNOP()); + } + } + + AsmFunc af; + af.name = fn.name; + af.type = fn.type; + af.asm_ = std::move(funcAsm); + af.argSize = argSize; + af.resultType = fn.resultType.value_or(0); + for (const auto &ann : fn.annotations) { + af.annotations.push_back({ann.name, ann.value}); + } + + normalizeASM(af); + + result.push_back(std::move(af)); + state.leaveFunction(); + } + + return result; +} + +} // namespace rspl diff --git a/cpp/src/ast2asm.h b/cpp/src/ast2asm.h new file mode 100644 index 0000000..385a749 --- /dev/null +++ b/cpp/src/ast2asm.h @@ -0,0 +1,13 @@ +#pragma once + +#include "asm.h" +#include "ast.h" + +#include + +namespace rspl { + +/// Convert a parsed AST program into per-function ASM. +std::vector ast2asm(const ast::Program &prog); + +} // namespace rspl diff --git a/cpp/src/astCalcNormalizer.cpp b/cpp/src/astCalcNormalizer.cpp new file mode 100644 index 0000000..25e0bd4 --- /dev/null +++ b/cpp/src/astCalcNormalizer.cpp @@ -0,0 +1,140 @@ +#include "astCalcNormalizer.h" + +#include +#include +#include + +namespace rspl { + +void applyPrecedence(std::vector &parts, int level) { + static const std::vector> precedence = { + {"*", "/"}, + {"+", "-"}, + {"<<", ">>", ">>>"}, + {"&"}, + {"^"}, + {"|"}, + }; + + for (const auto &ops : precedence) { + int idx = -1; + for (size_t i = 0; i <= parts.size(); i++) { + // sentinel for past-the-end (JS: parts[i] || " ") + if (i == parts.size()) { + if (idx >= 0) { + if (i != parts.size() - 1) { + std::vector sub(parts.begin() + idx, + parts.begin() + i); + FlatElem nested; + nested.kind = FlatElem::VAL; + nested.varName = NESTED_SENTINEL; + nested.nested = std::move(sub); + nested.isNested = true; + parts.erase(parts.begin() + idx, parts.begin() + i); + parts.insert(parts.begin() + idx, std::move(nested)); + i = idx + 1; + } + idx = -1; + } + break; + } + + FlatElem &part = parts[i]; + + // non-string in JS → non-OP in C++ + if (part.kind != FlatElem::OP) { + if (part.isNested && level == 0) + applyPrecedence(part.nested, level + 1); + continue; + } + + bool isPrecOp = + std::find(ops.begin(), ops.end(), part.opStr) != ops.end(); + if (idx == -1 && isPrecOp) { + idx = static_cast(i) - 1; + } + if (idx >= 0 && !isPrecOp) { + if (i != parts.size() - 1) { + std::vector sub(parts.begin() + idx, + parts.begin() + i); + FlatElem nested; + nested.kind = FlatElem::VAL; + nested.varName = NESTED_SENTINEL; + nested.nested = std::move(sub); + nested.isNested = true; + parts.erase(parts.begin() + idx, parts.begin() + i); + parts.insert(parts.begin() + idx, std::move(nested)); + i = idx + 1; + } + idx = -1; + } + } + } +} + +PartsResult partsEval(std::vector &parts, int level) { + for (size_t i = 0; i < parts.size(); i++) { + if (parts[i].isNested) { + auto nestedResult = partsEval(parts[i].nested, level + 1); + if (std::holds_alternative(nestedResult)) { + // bracket was completely evaluated into single value + parts[i] = std::get(std::move(nestedResult)); + i = static_cast(-1); // restart + } else { + parts[i].nested = + std::get>(std::move(nestedResult)); + } + } else if (parts[i].kind == FlatElem::VAL) { + // if both sides are (unswizzled) numbers, we can evaluate them + if (i + 2 >= parts.size()) continue; + if (parts[i + 1].kind != FlatElem::OP) continue; + if (parts[i + 2].kind != FlatElem::VAL) continue; + if (!parts[i].swizzle.empty() || !parts[i + 2].swizzle.empty()) + continue; + if (!parts[i].isNum || !parts[i + 2].isNum) continue; + + double valueL = parts[i].numVal; + double valueR = parts[i + 2].numVal; + std::string op = parts[i + 1].opStr; + + double newVal; + bool ok = true; + if (op == "+") newVal = valueL + valueR; + else if (op == "-") newVal = valueL - valueR; + else if (op == "*") newVal = valueL * valueR; + else if (op == "/") newVal = valueL / valueR; + else if (op == "<<") + newVal = static_cast(valueL) << static_cast(valueR); + else if (op == ">>") + newVal = static_cast(valueL) >> static_cast(valueR); + else if (op == ">>>") { + uint32_t u = static_cast(valueL); + u >>= static_cast(valueR); + newVal = static_cast(u); + } else if (op == "&") + newVal = static_cast(valueL) & static_cast(valueR); + else if (op == "^") + newVal = static_cast(valueL) ^ static_cast(valueR); + else if (op == "|") + newVal = static_cast(valueL) | static_cast(valueR); + else + ok = false; + + if (ok) { + // replace the 3 parts with the computed value + FlatElem folded; + folded.kind = FlatElem::VAL; + folded.isNum = true; + folded.numVal = newVal; + parts.erase(parts.begin() + i, parts.begin() + i + 3); + parts.insert(parts.begin() + i, std::move(folded)); + i--; // re-check this position + } + } + } + + if (parts.size() == 1) return parts[0]; + return parts; +} + +} // namespace rspl diff --git a/cpp/src/astCalcNormalizer.h b/cpp/src/astCalcNormalizer.h new file mode 100644 index 0000000..d70e612 --- /dev/null +++ b/cpp/src/astCalcNormalizer.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +namespace rspl { + +// A single element in the expression tree. Mirrors the JS representation +// where each element is either an operator string, a value object, or a +// nested sub-array. +struct FlatElem { + enum Kind { VAL, OP }; + Kind kind; + std::string opStr; // for OP ("+", "-", "*", "<<", "&", etc.) + double numVal = 0; // for VAL when numeric + std::string varName; // for VAL when variable + std::string swizzle; // swizzle on the value + bool isNum = false; // VAL is numeric + std::vector nested; // nested sub-expression + bool isNested = false; // true when wrapping a sub-expression +}; + +inline const char NESTED_SENTINEL[] = "\x01"; + +using PartsResult = std::variant>; + +// Evaluates constant sub-expressions at compile time. +// Returns a single FlatElem if the entire expression folded to a constant, +// or the (possibly modified) parts vector. +PartsResult partsEval(std::vector &parts, int level = 0); + +// Applies order-of-operations by nesting higher-precedence operators +// into sub-arrays. E.g. [a, "+", b, "*", c] -> [a, "+", [b, "*", c]]. +void applyPrecedence(std::vector &parts, int level = 0); + +} // namespace rspl diff --git a/cpp/src/builtins.cpp b/cpp/src/builtins.cpp new file mode 100644 index 0000000..c0aa047 --- /dev/null +++ b/cpp/src/builtins.cpp @@ -0,0 +1,1282 @@ +#include "builtins.h" +#include "asm.h" +#include "operations/branch.h" +#include "operations/scalar.h" +#include "operations/user_function.h" +#include "operations/vector.h" +#include "registers.h" +#include "state.h" +#include "swizzle.h" +#include "types.h" + +#include + +namespace rspl::builtins { + +using namespace rspl::ops; + +static const std::string LABEL_ASSERT = "assertion_failed"; + +namespace SP_STATUS { + constexpr int64_t HALTED = 1<<0; + constexpr int64_t BROKE = 1<<1; + constexpr int64_t DMA_BUSY = 1<<2; + constexpr int64_t DMA_FULL = 1<<3; + constexpr int64_t IO_FULL = 1<<4; + constexpr int64_t SSTEP = 1<<5; + constexpr int64_t INTR_BREAK = 1<<6; + constexpr int64_t SIG0 = 1<<7; + constexpr int64_t SIG1 = 1<<8; + constexpr int64_t SIG2 = 1<<9; + constexpr int64_t SIG3 = 1<<10; + constexpr int64_t SIG4 = 1<<11; + constexpr int64_t SIG5 = 1<<12; + constexpr int64_t SIG6 = 1<<13; + constexpr int64_t SIG7 = 1<<14; +} + +static const std::unordered_map DMA_FLAGS = { + {"DMA_IN_ASYNC", 0x00000000}, + {"DMA_OUT_ASYNC", 0xFFFF8000}, + {"DMA_IN", 0x00000000 | SP_STATUS::DMA_BUSY | SP_STATUS::DMA_FULL}, + {"DMA_OUT", 0xFFFF8000 | SP_STATUS::DMA_BUSY | SP_STATUS::DMA_FULL}, +}; + +// --- Helpers ---------------------------------------------------------- + +static void assertArgsNoSwizzle(const std::vector &args, int offset = 0) { + for (size_t i = offset; i < args.size(); ++i) { + if (!args[i].swizzle.empty()) { + state.throwError(offset > 0 + ? "Only the first " + std::to_string(offset) + " argument(s) can use swizzling!" + : "Arguments with swizzle not allowed in this function!"); + } + } +} + +static VarDef resolveArg(const ast::FuncArg &arg, const std::string &ctx) +{ + if (arg.type == ArgType::Num) { + VarDef v; + v.value = std::stoll(arg.value); + return v; + } + // Try register variable first, fall back to memory variable + if (state.varExists(arg.value)) { + VarDef v = state.getRequiredVarCopy(arg.value, ctx); + v.swizzle = arg.swizzle; + return v; + } + auto memVar = state.getRequiredVarOrMem(arg.value, ctx); + VarDef v; + v.type = toTypeClass(memVar.type); + v.reg = ""; // no register — it's a memory label + v.swizzle = arg.swizzle; + v.value = 0; + v.name = memVar.name; // store name for label reference + return v; +} + +// --- Builtin implementations ------------------------------------------ + +// load() +static std::vector +b_load(const VarDef *varRes, const std::vector &args, + const std::string &swizzle) { + assertArgsNoSwizzle(args); + if (!varRes) + state.throwError("Builtin load() needs a left-side"); + if (args.empty()) + state.throwError("Builtin load() requires at least 1 argument!"); + + auto argVar = state.getRequiredVarOrMem(args[0].value, "arg0"); + VarOrMem argOffset; + if (args.size() >= 2 && args[1].type == ArgType::Num) + argOffset.reg = args[1].value; + else if (args.size() >= 2) + argOffset = state.getRequiredVarOrMem(args[1].value, "arg1"); + + if (!argVar.reg.empty() && isVecType(argVar.type)) + state.throwError( + "Builtin load() requires first argument to be a scalar!"); + + if (reg::isVecReg(varRes->reg)) { + return opLoadVec(*varRes, argVar, argOffset, swizzle); + } + return opLoad(*varRes, argVar, argOffset); +} + +// load variant with extra flags +static std::vector +b_load_ex(const VarDef *varRes, const std::vector &args, + const std::string &swizzle, + bool isPackedByte, bool isSigned, bool isUnaligned) { + assertArgsNoSwizzle(args); + if (!varRes) state.throwError("Builtin load() needs a left-side"); + if (args.empty()) state.throwError("Builtin load() requires at least 1 argument!"); + auto argVar = state.getRequiredVarOrMem(args[0].value, "arg0"); + VarOrMem argOffset; + if (args.size() >= 2 && args[1].type == ArgType::Num) + argOffset.reg = args[1].value; + else if (args.size() >= 2) + argOffset = state.getRequiredVarOrMem(args[1].value, "arg1"); + if (!argVar.reg.empty() && isVecType(argVar.type)) + state.throwError("Builtin load() requires first argument to be a scalar!"); + if (reg::isVecReg(varRes->reg)) + return opLoadVec(*varRes, argVar, argOffset, swizzle, isPackedByte, isSigned, isUnaligned); + return opLoad(*varRes, argVar, argOffset); +} + +// store() +static std::vector +b_store(const VarDef *varRes, const std::vector &args, + const std::string &swizzle) { + assertArgsNoSwizzle(args, 1); + if (varRes) + state.throwError("Builtin store() cannot have a left side!"); + if (args.empty() || args[0].type != ArgType::Var) + state.throwError("Builtin store() requires first argument to be a variable!"); + + VarDef varSrc = resolveArg(args[0], "arg0"); + varSrc.swizzle = args[0].swizzle; + + std::vector offsets; + for (size_t i = 1; i < args.size(); ++i) { + if (args[i].type == ArgType::Num) { + VarOrMem v; + v.reg = args[i].value; + offsets.push_back(v); + } else { + offsets.push_back(state.getRequiredVarOrMem(args[i].value, "store_offset")); + } + } + + if (reg::isVecReg(varSrc.reg)) { + return opStoreVec(varSrc, offsets); + } + if (!varSrc.swizzle.empty()) + state.throwError("Scalar variables cannot use swizzling!"); + return opStore(varSrc, offsets); +} + +// store variant with extra flags +static std::vector +b_store_ex(const VarDef *varRes, const std::vector &args, + const std::string &swizzle, + bool isPackedByte, bool isSigned, bool isUnaligned) { + assertArgsNoSwizzle(args, 1); + if (varRes) state.throwError("Builtin store() cannot have a left side!"); + if (args.empty() || args[0].type != ArgType::Var) + state.throwError("Builtin store() requires first argument to be a variable!"); + VarDef varSrc = resolveArg(args[0], "arg0"); + varSrc.swizzle = args[0].swizzle; + std::vector offsets; + for (size_t i = 1; i < args.size(); ++i) { + if (args[i].type == ArgType::Num) { + VarOrMem v; v.reg = args[i].value; offsets.push_back(v); + } else { + offsets.push_back(state.getRequiredVarOrMem(args[i].value, "store_offset")); + } + } + if (reg::isVecReg(varSrc.reg)) + return opStoreVec(varSrc, offsets, isPackedByte, isSigned, isUnaligned); + if (!varSrc.swizzle.empty()) + state.throwError("Scalar variables cannot use swizzling!"); + return opStore(varSrc, offsets); +} + +// abs() +static std::vector +b_abs(const VarDef *varRes, const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError("Builtin abs() cannot use swizzle!"); + if (args.size() != 1) + state.throwError("Builtin abs() requires exactly one argument!"); + VarDef varArg = resolveArg(args[0], "arg0"); + if (!isVecType(varArg.type)) + state.throwError("Builtin abs() requires a vector argument!"); + if (!varRes) + state.throwError("Builtin abs() needs a left-side"); + + return {asmOp("vabs", {varRes->reg, varArg.reg, varArg.reg})}; +} + +// clear_vcc() +static std::vector +b_clear_vcc(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError("Builtin clear_vcc() cannot use swizzle!"); + if (varRes) + state.throwError("Builtin clear_vcc() must not have a left side!"); + if (!args.empty()) + state.throwError( + "Builtin clear_vcc() requires no arguments!"); + return {asmOp("vsubc", {reg::Reg::VTEMP0, reg::Reg::VZERO, + reg::Reg::VZERO})}; +} + +// get_vcc() +static std::vector +b_get_vcc(const VarDef *varRes, const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError("Builtin get_vcc() cannot use swizzle!"); + if (!varRes) + state.throwError("Builtin get_vcc() must have a left side!"); + if (!args.empty()) + state.throwError( + "Builtin get_vcc() requires no arguments!"); + if (reg::isVecReg(varRes->reg)) + state.throwError( + "Builtin get_vcc() must be assigned to a scalar variable!"); + return {asmOp("cfc2", {varRes->reg, reg::RegCop2::VCC})}; +} + +// clip(pos, planeW) -> u32 (clipping flags in VCC) +static std::vector +b_clip(const VarDef *varRes, const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError( + "To use swizzle in clip(), apply it to the second argument instead!"); + if (!varRes) + state.throwError("Builtin clip() must have a left side!"); + if (reg::isVecReg(varRes->reg)) + state.throwError("Builtin clip() must be assigned to a scalar variable!"); + if (args.size() != 2) + state.throwError("Builtin clip() requires exactly two arguments!"); + + VarDef varArg0 = resolveArg(args[0], "arg0"); + VarDef varArg1 = resolveArg(args[1], "arg1"); + + std::string swizzleRight; + if (!args[1].swizzle.empty()) { + auto sit = SWIZZLE_MAP.find(args[1].swizzle); + if (sit != SWIZZLE_MAP.end()) + swizzleRight = sit->second; + } + + if (!isVecType(varArg0.type)) + state.throwError("Builtin clip() requires first argument to be a vector!"); + if (!isVecType(varArg1.type)) + state.throwError("Builtin clip() requires second argument to be a vector!"); + bool is32BitA = (varArg0.type == TypeClass::Vec32); + bool is32BitB = (varArg1.type == TypeClass::Vec32); + if (is32BitA != is32BitB) + state.throwError( + "Builtin clip() requires both arguments to be of the same type!"); + + if (is32BitA) { + const std::string *nextReg0 = reg::nextVecReg(varArg0.reg); + const std::string *nextReg1 = reg::nextVecReg(varArg1.reg); + return { + asmOp("vch", + {reg::Reg::VTEMP0, varArg0.reg, varArg1.reg + swizzleRight}), + asmOp("vcl", + {reg::Reg::VTEMP0, *nextReg0, *nextReg1 + swizzleRight}), + asmOp("cfc2", {varRes->reg, reg::RegCop2::VCC}), + }; + } + return { + asmOp("vch", + {reg::Reg::VTEMP0, varArg0.reg, varArg1.reg + swizzleRight}), + asmOp("cfc2", {varRes->reg, reg::RegCop2::VCC}), + }; +} + +// set_vcc() +static std::vector +b_set_vcc(const VarDef *varRes, const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError("Builtin set_vcc() cannot use swizzle!"); + if (varRes) + state.throwError("Builtin set_vcc() must not have a left side!"); + if (args.size() != 1) + state.throwError("Builtin set_vcc() requires 1 scalar argument!"); + + std::string reg = reg::Reg::AT; + std::vector res; + if (args[0].type == ArgType::Num) { + auto load = loadImmediate(reg, args[0].value); + res.insert(res.end(), load.begin(), load.end()); + } else { + VarDef varArg = resolveArg(args[0], "arg0"); + if (isVecType(varArg.type)) + state.throwError( + "Builtin set_vcc() requires a scalar argument!"); + reg = varArg.reg; + } + res.push_back(asmOp("ctc2", {reg, reg::RegCop2::VCC})); + return res; +} + +// get_acc() -> vec32 +static std::vector +b_get_acc(const VarDef *varRes, const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError("Builtin get_acc() cannot use swizzle!"); + if (!varRes) + state.throwError("Builtin get_acc() must have a left side!"); + if (!args.empty()) + state.throwError( + "Builtin get_acc() requires no arguments!"); + if (!reg::isVecReg(varRes->reg)) + state.throwError( + "Builtin get_acc() must be assigned to a vector variable!"); + if (varRes->type != TypeClass::Vec32) + state.throwError( + "Builtin get_acc() must be assigned to a vec32 variable!\n" + "Use get_acc_high/mid/low."); + return {asmOp("vsar", {varRes->reg, reg::RegCop2::ACC_HI}), + asmOp("vsar", + {*reg::nextVecReg(varRes->reg), reg::RegCop2::ACC_MD})}; +} + +// get_acc_high/mid/low +static std::vector +b_get_acc_part(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle, const std::string &part) { + if (!swizzle.empty()) + state.throwError("Builtin get_acc_*() cannot use swizzle!"); + if (!varRes) + state.throwError("Builtin get_acc_*() must have a left side!"); + if (!args.empty()) + state.throwError( + "Builtin get_acc_*() requires no arguments!"); + if (!reg::isVecReg(varRes->reg)) + state.throwError( + "Builtin get_acc_*() must be assigned to a vector variable!"); + if (varRes->type != TypeClass::Vec16) + state.throwError( + "Builtin get_acc_*() must be assigned to a vec16 variable!\n" + "Use get_acc()."); + return {asmOp("vsar", {varRes->reg, part})}; +} + +// mfc0 reads (generic) +static std::vector +b_mfc0_read(const VarDef *varRes, const std::string &rdpReg, + const std::string &name) { + if (!varRes) + state.throwError("Builtin " + name + "() must have a left side!"); + if (reg::isVecReg(varRes->reg)) + state.throwError("Builtin " + name + + "() must be assigned to a scalar variable!"); + return {asmOp("mfc0", {varRes->reg, rdpReg})}; +} + +// mtc0 writes (generic) +static std::vector +b_mtc0_write(const VarDef *varRes, + const std::vector &args, + const std::string &rdpReg) { + if (varRes) + state.throwError("Builtin must not have a left side!"); + if (args.size() != 1) + state.throwError("Builtin requires 1 scalar argument!"); + std::string reg = reg::Reg::AT; + std::vector res; + if (args[0].type == ArgType::Num) { + auto load = loadImmediate(reg, args[0].value); + res.insert(res.end(), load.begin(), load.end()); + } else { + VarDef varArg = resolveArg(args[0], "arg0"); + if (isVecType(varArg.type)) + state.throwError("Builtin requires a scalar argument!"); + reg = varArg.reg; + } + res.push_back(asmOp("mtc0", {reg, rdpReg})); + return res; +} + +// get_cmd_address() +static std::vector +b_get_cmd_address(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError( + "Builtin get_cmd_address() cannot use swizzle!"); + if (!varRes) + state.throwError( + "Builtin get_cmd_address() must have a left side!"); + if (args.size() > 1) + state.throwError( + "Builtin get_cmd_address() requires zero or one argument!"); + if (args.size() == 1 && args[0].type != ArgType::Num) + state.throwError( + "Builtin get_cmd_address() requires the argument to be a number!"); + if (reg::isVecReg(varRes->reg)) + state.throwError( + "Builtin get_cmd_address() must be assigned to a scalar variable!"); + int offset = args.empty() ? 0 : std::stoi(args[0].value); + offset -= state.argSize; + // Match JS format: "NAME ${sign} ${abs(offset)}" where sign is empty for negative + std::string offStr = "%lo(RSPQ_DMEM_BUFFER)"; + offStr += " " + std::string(offset < 0 ? "" : "+") + " " + + std::to_string(offset); + return {asmOp("addiu", {varRes->reg, reg::Reg::GP, offStr})}; +} + +// load_arg() +static std::vector +b_load_arg(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError("Builtin load_arg() cannot use swizzle!"); + if (!varRes) + state.throwError("Builtin load_arg() must have a left side!"); + if (args.size() > 1) + state.throwError( + "Builtin load_arg() requires zero or one argument!"); + if (args.size() == 1 && args[0].type != ArgType::Num) + state.throwError( + "Builtin load_arg() requires the argument to be a number!"); + if (reg::isVecReg(varRes->reg)) + state.throwError( + "Builtin load_arg() must be assigned to a scalar variable!"); + int offset = args.empty() ? 0 : std::stoi(args[0].value); + offset -= state.argSize; + VarOrMem loc; + loc.reg = reg::Reg::GP; + VarOrMem off; + // Match JS format: "%lo(NAME ${sign} ${abs(offset)})" + // where sign is empty for negative + off.reg = std::string("%lo(RSPQ_DMEM_BUFFER") + " " + + std::string(offset < 0 ? "" : "+") + " " + + std::to_string(offset) + ")"; + return opLoad(*varRes, loc, off); +} + +// dma_in / dma_out / dma_in_async / dma_out_async +static std::vector +b_dma(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle, const std::string &builtinName, + const std::string &dmaName) { + assertArgsNoSwizzle(args); + if (!swizzle.empty()) + state.throwError("Builtin " + builtinName + + "() cannot use swizzle!"); + if (varRes) + state.throwError("Builtin " + builtinName + + "() cannot have a left side!"); + + auto targetMem = state.getRequiredVarOrMem(args[0].value, "dest"); + VarDef varRDRAM = resolveArg(args[1], "RDRAM"); + + // Require size argument when dest is a register variable + if (targetMem.reg.empty() == false && args.size() != 3) { + state.throwError("Builtin " + builtinName + + "() requires size-argument when using a variable as destination!"); + } + + std::vector res; + if (varRDRAM.reg != reg::Reg::S0) { + res.push_back(asmOp("or", + {reg::Reg::S0, reg::Reg::ZERO, varRDRAM.reg})); + } + + std::vector sizeLoadOps; + // Explicit size (3-arg form) + if (args.size() == 3) { + const auto &sizeArg = args[2]; + if (sizeArg.type == ArgType::Num) { + int dmaSize = (std::stoi(sizeArg.value) - 1); + sizeLoadOps.push_back( + asmOp("ori", {reg::Reg::T0, reg::Reg::ZERO, + std::to_string(dmaSize)})); + } else { + VarDef sizeVar = state.getRequiredVarCopy(sizeArg.value, "size"); + if (sizeVar.reg != reg::Reg::T0) + state.throwError("Builtin " + builtinName + + "() requires size-argument to be in $t0!"); + sizeLoadOps.push_back( + asmOp("addiu", {reg::Reg::T0, reg::Reg::T0, "-1"})); + } + + if (!targetMem.reg.empty()) { + if (targetMem.reg != reg::Reg::S4) + state.throwError("Builtin " + builtinName + + "() requires dest. var to be in $s4!"); + } else { + sizeLoadOps.push_back( + asmOp("ori", {reg::Reg::S4, reg::Reg::ZERO, + "%lo(" + targetMem.name + ")"})); + } + } else { + // No explicit size: use declared state size + int targetSize = TYPE_SIZE.at(targetMem.type) * targetMem.arraySize; + int dmaSize = (targetSize - 1); + sizeLoadOps.push_back( + asmOp("ori", {reg::Reg::T0, reg::Reg::ZERO, + std::to_string(dmaSize)})); + sizeLoadOps.push_back( + asmOp("ori", {reg::Reg::S4, reg::Reg::ZERO, + "%lo(" + targetMem.name + ")"})); + } + + res.insert(res.end(), sizeLoadOps.begin(), sizeLoadOps.end()); + + auto flagsIt = DMA_FLAGS.find(dmaName); + auto loadflags = loadImmediate(reg::Reg::T2, + std::to_string(flagsIt != DMA_FLAGS.end() ? flagsIt->second + : 0)); + res.insert(res.end(), loadflags.begin(), loadflags.end()); + res.push_back( + asmFunction("DMAExec", + {reg::Reg::T0, reg::Reg::T1, reg::Reg::S0, + reg::Reg::S4, reg::Reg::T2})); + res.push_back(asmNOP()); + return res; +} + +// dma_await() +static std::vector +b_dma_await(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError( + "Builtin dma_await() cannot use swizzle!"); + if (varRes) + state.throwError( + "Builtin dma_await() cannot have a left side!"); + if (!args.empty()) + state.throwError( + "Builtin dma_await() requires no arguments!"); + return {asmFunction("DMAWaitIdle", {}), asmNOP()}; +} + +// swap() +static std::vector +b_swap(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + assertArgsNoSwizzle(args); + if (!swizzle.empty()) + state.throwError("Builtin swap() cannot use swizzle!"); + if (varRes) + state.throwError("Builtin swap() cannot have a left side!"); + if (args.size() != 2) + state.throwError( + "Builtin swap() requires exactly two arguments!"); + + VarDef varA = resolveArg(args[0], "arg0"); + VarDef varB = resolveArg(args[1], "arg1"); + + if (varA.reg == varB.reg) return {}; + + std::string op = isVecType(varA.type) ? "vxor" : "xor"; + std::vector res; + res.push_back(asmOp(op, {varA.reg, varA.reg, varB.reg})); + res.push_back(asmOp(op, {varB.reg, varA.reg, varB.reg})); + res.push_back(asmOp(op, {varA.reg, varA.reg, varB.reg})); + + if (isTwoRegType(varA.type)) { + std::string ra = *reg::nextReg(varA.reg); + std::string rb = *reg::nextReg(varB.reg); + res.push_back(asmOp(op, {ra, ra, rb})); + res.push_back(asmOp(op, {rb, ra, rb})); + res.push_back(asmOp(op, {ra, ra, rb})); + } + return res; +} + +// min / max +static std::vector +b_minmax(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle, const std::string &compareOp) { + if (!swizzle.empty()) + state.throwError("Builtin min/max() cannot use swizzle!"); + if (args.size() != 2) + state.throwError( + "Builtin min/max() requires exactly two arguments!"); + + VarDef varA = resolveArg(args[0], "arg0"); + VarDef varB = resolveArg(args[1], "arg1"); + // JS max()/min() do not propagate swizzle from the function arguments + // to the comparison operation (max is element-wise, not swizzled). + varA.swizzle.clear(); + varB.swizzle.clear(); + if (!varRes) + state.throwError("Builtin min/max() needs a left-side"); + + return opCompareVec(*varRes, varA, varB, compareOp, nullptr); +} + +// invert_half / invert_half_sqrt +static std::vector +b_invert_half(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + assertArgsNoSwizzle(args); + if (args.size() != 1) + state.throwError( + "Builtin invert_half() requires exactly one argument!"); + VarDef varArg = resolveArg(args[0], "arg0"); + if (!isVecType(varArg.type)) + state.throwError( + "Builtin invert_half() requires a vector argument!"); + if (!varRes) + state.throwError("Builtin invert_half() needs a left-side"); + + VarDef argSwiz = varArg; + argSwiz.swizzle = swizzle; + return opInvertHalf(*varRes, argSwiz); +} + +static std::vector +b_invert(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + // invert = invert_half + multiply by 2 + if (swizzle.size()) + state.throwError( + "Builtin invert() cannot use swizzle, use invert_half() instead"); + auto res = b_invert_half(varRes, args, swizzle); + VarDef mulRight; + mulRight.value = 2; + auto mulAsm = opMulVec(*varRes, *varRes, mulRight, true); + res.insert(res.end(), mulAsm.begin(), mulAsm.end()); + return res; +} + +static std::vector +b_invert_half_sqrt(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + assertArgsNoSwizzle(args); + if (args.size() != 1) + state.throwError( + "Builtin invert_half_sqrt() requires exactly one argument!"); + VarDef varArg = resolveArg(args[0], "arg0"); + if (!isVecType(varArg.type)) + state.throwError( + "Builtin invert_half_sqrt() requires a vector argument!"); + if (!varRes) + state.throwError( + "Builtin invert_half_sqrt() needs a left-side"); + + // JS: varRes keeps its own swizzle (from the assignment target, + // e.g. vLenInv.w), while the swizzle on the function call + // (e.g. invert_half_sqrt(x).x) is applied only to the argument. + VarDef argSwiz = varArg; + argSwiz.swizzle = swizzle; + return opInvertSqrtHalf(*varRes, argSwiz); +} + +// select() +static std::vector +b_select(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError( + "Builtin select() cannot use swizzle!"); + if (args.size() != 2) + state.throwError( + "Builtin select() requires exactly two arguments!"); + + VarDef varLeft, varRight; + if (args[0].type == ArgType::Num) { + varLeft.reg = reg::Reg::VZERO; + varLeft.type = TypeClass::Vec16; + } else { + varLeft = resolveArg(args[0], "arg0"); + } + if (args[1].type == ArgType::Num) { + auto pIt = POW2_SWIZZLE_VAR.find(std::stoll(args[1].value)); + if (pIt == POW2_SWIZZLE_VAR.end()) + state.throwError( + "Second arg must be a variable or power-of-two constant!"); + varRight.reg = pIt->second.reg; + varRight.swizzle = pIt->second.swizzle; + varRight.type = TypeClass::Vec16; + } else { + varRight = resolveArg(args[1], "arg1"); + } + if (!varRes) + state.throwError("Builtin select() needs a left-side"); + + auto regsDst = ops::getVec32Regs(*varRes); + auto regsL = ops::getVec32Regs(varLeft); + auto regsR = ops::getVec32Regs(varRight); + if (!isTwoRegType(varRes->type)) { + regsL.first = varLeft.reg; + regsR.first = varRight.reg; + } + auto sit = SWIZZLE_MAP.find(varRight.swizzle); + std::string swSuffix = + sit != SWIZZLE_MAP.end() ? sit->second : ""; + if (swSuffix == ".v") swSuffix = ""; + + std::vector args1 = {regsDst.first, regsL.first, + regsR.first + swSuffix}; + std::vector args2 = {regsDst.second, regsL.second, + regsR.second + swSuffix}; + return {asmOp("vmrg", args1), asmOp("vmrg", args2)}; +} + +// assert() +static std::vector +b_assert(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError("Builtin assert() cannot use swizzle!"); + if (varRes) + state.throwError("Builtin assert() cannot have a left side!"); + if (args.size() != 1) + state.throwError("Builtin assert() requires exactly one argument!"); + if (args[0].type != ArgType::Num) + state.throwError( + "Builtin assert() requires the argument to be a number!"); + int code = std::stoi(args[0].value); + return {asmOp("lui", {reg::Reg::AT, std::to_string(code)}), + asmOp("j", {LABEL_ASSERT}), asmNOP()}; +} + +// asm() — raw inline assembly +static std::vector +b_asm(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError("Builtin asm() cannot use swizzle!"); + if (varRes) + state.throwError("Builtin asm() cannot have a left side!"); + if (args.empty() || args[0].type != ArgType::String) + state.throwError( + "Builtin asm() requires the first argument to be a string!"); + + std::string str = args[0].value; + for (size_t i = 1; i < args.size(); ++i) { + const auto &arg = args[i]; + std::string replacement; + if (arg.type == ArgType::Num) { + replacement = arg.value; + } else { + const VarDef *varArg = state.getRequiredVar(arg.value, "arg" + + std::to_string(i)); + replacement = varArg->reg; + } + std::string placeholder = "%" + std::to_string(i - 1); + size_t pos = 0; + while ((pos = str.find(placeholder, pos)) != std::string::npos) { + str.replace(pos, placeholder.length(), replacement); + pos += replacement.length(); + } + } + return {asmInline(str, {"# inline-ASM"})}; +} + +// transpose() +// Valid transpose registers: $v00, $v08, $v16, $v24 +static const std::vector VALID_TRANSPOSE_REGS = { + "$v00", "$v08", "$v16", "$v24"}; + +static std::vector +b_transpose(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError("Builtin transpose() cannot use swizzle!"); + if (!varRes) + state.throwError("Builtin transpose() needs a left-side"); + if (!reg::isVecReg(varRes->reg)) + state.throwError( + "Builtin transpose() must store the result into a vector!"); + if (args.size() != 4) + state.throwError("Builtin transpose() requires 4 arguments!"); + + VarDef varSrc = resolveArg(args[0], "arg0"); + if (!isVecType(varSrc.type)) + state.throwError( + "Builtin transpose() requires first argument to be a vector!"); + VarDef buffVar = resolveArg(args[1], "arg1"); + if (isVecType(buffVar.type)) + state.throwError( + "Builtin transpose() requires second argument to be a scalar!"); + if (args[2].type != ArgType::Num) + state.throwError( + "Builtin transpose() requires third argument to be a number!"); + if (args[3].type != ArgType::Num) + state.throwError( + "Builtin transpose() requires fourth argument to be a number!"); + + int dimX = std::stoi(args[2].value); + int dimY = std::stoi(args[3].value); + if (dimX < 1 || dimX > 8 || dimY < 1 || dimY > 8) + state.throwError( + "Builtin transpose() requires X and Y dimension to be between 1 " + "and 8!"); + + if (std::find(VALID_TRANSPOSE_REGS.begin(), + VALID_TRANSPOSE_REGS.end(), + varRes->reg) == VALID_TRANSPOSE_REGS.end()) + state.throwError("Builtin transpose() requires target register to be " + "$v00, $v08, $v16 or $v24!"); + if (std::find(VALID_TRANSPOSE_REGS.begin(), + VALID_TRANSPOSE_REGS.end(), + varSrc.reg) == VALID_TRANSPOSE_REGS.end()) + state.throwError("Builtin transpose() requires source register to be " + "$v00, $v08, $v16 or $v24!"); + + bool isInPlace = (varSrc.reg == varRes->reg); + bool is8x8 = (dimX > 4 || dimY > 4); + + // Barrier to prevent reordering across the transpose (matching JS) + state.addAnnotation("Barrier", state.generateLabel()); + + std::string bufReg = buffVar.reg; + + std::vector res; + // STV stores + if (!isInPlace) + res.push_back( + asmOp("stv", {varSrc.reg, "0", "0", bufReg})); + res.push_back(asmOp("stv", {varSrc.reg, "2", "16", bufReg})); + res.push_back(asmOp("stv", {varSrc.reg, "4", "32", bufReg})); + res.push_back(asmOp("stv", {varSrc.reg, "6", "48", bufReg})); + if (is8x8) + res.push_back(asmOp("stv", {varSrc.reg, "8", "64", bufReg})); + res.push_back(asmOp("stv", {varSrc.reg, "10", "80", bufReg})); + res.push_back(asmOp("stv", {varSrc.reg, "12", "96", bufReg})); + res.push_back(asmOp("stv", {varSrc.reg, "14", "112", bufReg})); + + // LTV loads + res.push_back(asmOp("ltv", {varRes->reg, "14", "16", bufReg})); + res.push_back(asmOp("ltv", {varRes->reg, "12", "32", bufReg})); + res.push_back(asmOp("ltv", {varRes->reg, "10", "48", bufReg})); + if (is8x8) + res.push_back(asmOp("ltv", {varRes->reg, "8", "64", bufReg})); + res.push_back(asmOp("ltv", {varRes->reg, "6", "80", bufReg})); + res.push_back(asmOp("ltv", {varRes->reg, "4", "96", bufReg})); + res.push_back(asmOp("ltv", {varRes->reg, "2", "112", bufReg})); + if (!isInPlace) + res.push_back( + asmOp("ltv", {varRes->reg, "0", "0", bufReg})); + + return res; +} + +// asm_op() +static std::vector +b_asm_op(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + if (!swizzle.empty()) + state.throwError("Builtin asm_op() cannot use swizzle!"); + if (varRes) + state.throwError("Builtin asm_op() cannot have a left side!"); + if (args.empty() || args[0].type != ArgType::String) + state.throwError( + "Builtin asm_op() requires the first argument to be a opcode!"); + + std::vector asmArgs; + for (size_t i = 1; i < args.size(); ++i) { + if (args[i].type == ArgType::Num) { + asmArgs.push_back(args[i].value); + } else { + const VarDef *v = state.getRequiredVar(args[i].value, "arg"); + std::string sw; + if (isVecType(v->type)) { + auto sit = SWIZZLE_MAP.find(args[i].swizzle); + sw = sit != SWIZZLE_MAP.end() ? sit->second : ""; + } + asmArgs.push_back(v->reg + sw); + } + } + return {asmOp(args[0].value, asmArgs)}; +} + +// asm_include() +static std::vector +b_asm_include(const VarDef *varRes, + const std::vector &args, + const std::string &swizzle) { + if (args.empty() || args[0].type != ArgType::String) + state.throwError("Builtin asm_include() requires a path argument!"); + + std::vector res; + // Emit #defines for all scalar registers + for (size_t i = 0; i < reg::REGS_SCALAR.size(); ++i) { + if (i == 1) continue; // skip $at + std::string name = reg::REGS_SCALAR[i].substr(1); + res.push_back( + asmInline("#define " + name + " $" + std::to_string(i))); + } + res.push_back(asmInline(".set at")); + res.push_back(asmInline(".set macro")); + res.push_back( + asmInline("#include \"" + args[0].value + "\"")); + res.push_back(asmInline(".set noreorder")); + res.push_back(asmInline(".set noat")); + res.push_back(asmInline(".set nomacro")); + for (const auto ® : reg::REGS_SCALAR) { + res.push_back(asmInline("#undef " + reg.substr(1))); + } + return res; +} + +// --- Registry --------------------------------------------------------- + +using BuiltinMap = + std::unordered_map; + +static BuiltinMap buildRegistry() { + BuiltinMap m; + + m["load"] = b_load; + m["store"] = b_store; + m["load_vec_u8"] = [](const VarDef *vr, + const std::vector &args, + const std::string &swizzle) { + return b_load_ex(vr, args, swizzle, true, false, false); + }; + m["load_vec_s8"] = [](const VarDef *vr, + const std::vector &args, + const std::string &swizzle) { + return b_load_ex(vr, args, swizzle, true, true, false); + }; + m["store_vec_u8"] = [](const VarDef *vr, + const std::vector &args, + const std::string &swizzle) { + return b_store_ex(vr, args, swizzle, true, false, false); + }; + m["store_vec_s8"] = [](const VarDef *vr, + const std::vector &args, + const std::string &swizzle) { + return b_store_ex(vr, args, swizzle, true, true, false); + }; + m["load_unaligned"] = [](const VarDef *vr, + const std::vector &args, + const std::string &swizzle) { + return b_load_ex(vr, args, swizzle, false, true, true); + }; + m["store_unaligned"] = [](const VarDef *vr, + const std::vector &args, + const std::string &swizzle) { + return b_store_ex(vr, args, swizzle, false, true, true); + }; + m["abs"] = b_abs; + m["clear_vcc"] = b_clear_vcc; + m["get_vcc"] = b_get_vcc; + m["set_vcc"] = b_set_vcc; + m["clip"] = b_clip; + m["get_acc"] = b_get_acc; + m["swap"] = b_swap; + m["select"] = b_select; + m["assert"] = b_assert; + m["asm"] = b_asm; + m["get_cmd_address"] = b_get_cmd_address; + m["load_arg"] = b_load_arg; + m["dma_await"] = b_dma_await; + + // Stubs for other builtins + m["load_transposed"] = [](const VarDef *vr, + const std::vector &args, + const std::string &swizzle) -> std::vector { + if (!swizzle.empty()) state.throwError("Builtin load_transposed() cannot use swizzle!"); + if (!vr) state.throwError("Builtin load_transposed() needs a left-side"); + if (!reg::isVecReg(vr->reg)) state.throwError("Builtin load_transposed() must store the result into a vector!"); + if (args.size() < 2 || args.size() > 3) state.throwError("Builtin load_transposed() requires 2 or 3 arguments!"); + if (args[0].type != ArgType::Num) state.throwError("Builtin load_transposed() requires first argument to be a number (row offset 0-7)!"); + int row = std::stoi(args[0].value); + if (row < 0 || row > 7) state.throwError("Builtin load_transposed() requires first argument (row) to be a number between 0 and 7!"); + int offset = 0; + if (args.size() >= 3) { + if (args[2].type != ArgType::Num) state.throwError("Builtin load_transposed() requires third argument to be a number (offset in steps of 0x10)!"); + offset = std::stoi(args[2].value); + if (offset % 16 != 0) state.throwError("Builtin load_transposed() requires offset to be multiple of 16!"); + } + auto addrMem = state.getRequiredVarOrMem(args[1].value, "addr"); + if (!addrMem.reg.empty() && reg::isVecReg(addrMem.reg)) state.throwError("Builtin load_transposed() requires second argument to be a scalar variable!"); + // Register must be v00/v08/v16/v24 + std::string reg = vr->reg; + if (reg != "$v00" && reg != "$v08" && reg != "$v16" && reg != "$v24") + state.throwError("Builtin load_transposed() requires result register to be $v00, $v08, $v16 or $v24!"); + std::string baseReg = addrMem.reg.empty() ? "$zero" : addrMem.reg; + std::string offStr = std::to_string(offset); + if (!addrMem.reg.empty()) { + return {asmOp("ltv", {reg, std::to_string(row * 2), offStr, baseReg})}; + } + auto loadAt = loadImmediate("$at", "%lo(" + addrMem.name + ")"); + loadAt.push_back(asmOp("ltv", {reg, std::to_string(row * 2), offStr, "$at"})); + return loadAt; + }; + m["store_transposed"] = [](const VarDef *vr, + const std::vector &args, + const std::string &swizzle) -> std::vector { + if (!swizzle.empty()) state.throwError("Builtin store_transposed() cannot use swizzle!"); + if (vr) state.throwError("Builtin store_transposed() has no left-side"); + if (args.size() < 3 || args.size() > 4) state.throwError("Builtin store_transposed() requires 3 or 4 arguments!"); + // args[0] = value to store, args[1] = row, args[2] = addr, args[3] = offset (optional) + const VarDef *valVar = state.getRequiredVar(args[0].value, "arg0"); + if (!reg::isVecReg(valVar->reg)) state.throwError("Builtin store_transposed() must target a vector register!"); + if (args[1].type != ArgType::Num) state.throwError("Builtin store_transposed() requires second argument to be a number (row offset 0-7)!"); + int row = std::stoi(args[1].value); + if (row < 0 || row > 7) state.throwError("Builtin store_transposed() requires second argument (row) to be a number between 0 and 7!"); + int offset = 0; + if (args.size() >= 4) { + if (args[3].type != ArgType::Num) state.throwError("Builtin store_transposed() requires fourth argument to be a number (offset in steps of 0x10)!"); + offset = std::stoi(args[3].value); + if (offset % 16 != 0) state.throwError("Builtin store_transposed() requires offset to be multiple of 16!"); + } + auto addrMem = state.getRequiredVarOrMem(args[2].value, "addr"); + if (!addrMem.reg.empty() && reg::isVecReg(addrMem.reg)) state.throwError("Builtin store_transposed() requires third argument to be a scalar variable!"); + if (valVar->reg != "$v00" && valVar->reg != "$v08" && valVar->reg != "$v16" && valVar->reg != "$v24") + state.throwError("Builtin store_transposed() requires target register to be $v00, $v08, $v16 or $v24!"); + std::string baseReg = addrMem.reg.empty() ? "$zero" : addrMem.reg; + std::string offStr = std::to_string(offset); + if (!addrMem.reg.empty()) { + return {asmOp("stv", {valVar->reg, std::to_string(row * 2), offStr, baseReg})}; + } + auto loadAt = loadImmediate("$at", "%lo(" + addrMem.name + ")"); + loadAt.push_back(asmOp("stv", {valVar->reg, std::to_string(row * 2), offStr, "$at"})); + return loadAt; + }; + m["transpose"] = b_transpose; + m["asm_op"] = b_asm_op; + m["asm_include"] = b_asm_include; + + // get_acc_high/mid/low + m["get_acc_high"] = [](const VarDef *vr, + const std::vector &a, + const std::string &s) { + return b_get_acc_part(vr, a, s, reg::RegCop2::ACC_HI); + }; + m["get_acc_mid"] = [](const VarDef *vr, + const std::vector &a, + const std::string &s) { + return b_get_acc_part(vr, a, s, reg::RegCop2::ACC_MD); + }; + m["get_acc_low"] = [](const VarDef *vr, + const std::vector &a, + const std::string &s) { + return b_get_acc_part(vr, a, s, reg::RegCop2::ACC_LO); + }; + + // MFC0 reads + auto addMfc0Read = [&](const char *name, const char *rdpReg) { + m[name] = [name, rdpReg](const VarDef *vr, + const std::vector &a, + const std::string &s) -> std::vector { + if (!s.empty()) + state.throwError(std::string("Builtin ") + name + + "() cannot use swizzle!"); + if (!a.empty()) + state.throwError(std::string("Builtin ") + name + + "() requires no arguments!"); + return b_mfc0_read(vr, rdpReg, name); + }; + }; + addMfc0Read("get_dma_busy", reg::RegCop0::DMA_BUSY); + addMfc0Read("get_rdp_start", reg::RegCop0::DP_START); + addMfc0Read("get_rdp_end", reg::RegCop0::DP_END); + addMfc0Read("get_rdp_current", reg::RegCop0::DP_CURRENT); + addMfc0Read("get_ticks", reg::RegCop0::DP_CLOCK); + + // MTC0 writes + auto addMtc0Write = [&](const char *name, const char *rdpReg) { + m[name] = [name, rdpReg](const VarDef *vr, + const std::vector &a, + const std::string &s) -> std::vector { + if (!s.empty()) + state.throwError(std::string("Builtin ") + name + + "() cannot use swizzle!"); + return b_mtc0_write(vr, a, rdpReg); + }; + }; + addMtc0Write("set_rdp_start", reg::RegCop0::DP_START); + addMtc0Write("set_rdp_end", reg::RegCop0::DP_END); + addMtc0Write("set_rdp_current", reg::RegCop0::DP_CURRENT); + addMtc0Write("set_dma_addr_rsp", reg::RegCop0::DMA_SPADDR); + addMtc0Write("set_dma_addr_rdram", reg::RegCop0::DMA_RAMADDR); + addMtc0Write("set_dma_write", reg::RegCop0::DMA_WRITE); + addMtc0Write("set_dma_read", reg::RegCop0::DMA_READ); + + // DMA + auto addDma = [&](const char *name, const char *dmaFlag) { + m[name] = [name, dmaFlag](const VarDef *vr, + const std::vector &a, + const std::string &s) { + return b_dma(vr, a, s, name, dmaFlag); + }; + }; + addDma("dma_in", "DMA_IN"); + addDma("dma_out", "DMA_OUT"); + addDma("dma_in_async", "DMA_IN_ASYNC"); + addDma("dma_out_async", "DMA_OUT_ASYNC"); + + // print / printf + auto b_print = [](const VarDef *varRes, + const std::vector &args, + const std::string &) -> std::vector { + if (varRes) + state.throwError("Builtin print() cannot have a left side!"); + if (args.empty()) + state.throwError( + "Builtin print() requires at least one argument!"); + + for (const auto &arg : args) { + if (arg.type == ArgType::Num) + state.throwError("Builtin print() requires all arguments to " + "be variables or strings!"); + } + + auto mainType = args[0].type; + for (const auto &arg : args) { + if (arg.type != mainType) + state.throwError( + "Builtin print() requires all arguments to be of the " + "same type!"); + } + + std::vector res; + res.push_back(asmInline(".set macro", {"# print"})); + + if (mainType == ArgType::String) { + std::vector strArgs; + for (const auto &arg : args) + strArgs.push_back("\"" + arg.value + "\""); + res.push_back(asmInline("emux_log_string", strArgs)); + } else { + // Resolve first arg to determine scalar vs vector + VarDef arg0 = resolveArg(args[0], "arg0"); + state.logInfo("Info: print() variable '" + arg0.name + + "' is " + arg0.reg); + bool isVector = isVecType(arg0.type); + + for (size_t i = 1; i < args.size(); ++i) { + VarDef argVar = resolveArg(args[i], "arg" + std::to_string(i)); + state.logInfo("Info: print() variable '" + argVar.name + + "' is " + argVar.reg); + if (isVecType(argVar.type) != isVector) + state.throwError( + "Builtin print() doesn't allow mixed scalar/vector " + "arguments!"); + } + + std::string op = isVector ? "emux_dump_vpr" : "emux_dump_gpr"; + std::vector regArgs; + for (const auto &arg : args) { + VarDef argVar = resolveArg(arg, "arg"); + regArgs.push_back(argVar.reg); + } + res.push_back(asmInline(op, regArgs)); + } + + res.push_back(asmInline(".set noat", {"# print"})); + res.push_back(asmInline(".set nomacro", {"# print"})); + return res; + }; + + auto b_printf = [](const VarDef *varRes, + const std::vector &args, + const std::string &) -> std::vector { + if (varRes) + state.throwError("Builtin printf() cannot have a left side!"); + if (args.empty()) + state.throwError( + "Builtin printf() requires at least one argument!"); + if (args[0].type != ArgType::String) + state.throwError( + "Builtin printf() requires first argument to be a string!"); + + std::vector res; + res.push_back(asmInline(".set macro", {"# print"})); + + std::string fmt = args[0].value; + std::string fmtString; + size_t argIdx = 1; + + // Parse format string for %v/%d/%u/%x/%f specifiers (matching JS + // regex: /(%[vduxf])/) + size_t pos = 0; + while (pos < fmt.size()) { + size_t pct = fmt.find('%', pos); + if (pct == std::string::npos) { + fmtString += fmt.substr(pos); + break; + } + fmtString += fmt.substr(pos, pct - pos); + if (pct + 1 < fmt.size()) { + char specChar = fmt[pct + 1]; + if (specChar == 'v' || specChar == 'd' || specChar == 'u' || + specChar == 'x' || specChar == 'f') { + std::string spec = fmt.substr(pct, 2); + if (argIdx < args.size()) { + const auto &val = args[argIdx++]; + if (val.type == ArgType::Var) { + VarDef refVar = resolveArg(val, "arg" + + std::to_string(argIdx)); + if (reg::isVecReg(refVar.reg)) { + auto it = SWIZZLE_MAP.find(val.swizzle); + std::string sw = + it != SWIZZLE_MAP.end() ? it->second : ""; + if (refVar.type == TypeClass::Vec32) { + fmtString += "%f" + refVar.reg.substr(1) + sw; + } else { + fmtString += "%d" + refVar.reg.substr(1) + sw; + } + } else { + fmtString += spec + refVar.reg.substr(1); + } + } + } + pos = pct + 2; + } else { + fmtString += fmt[pct]; + pos = pct + 1; + } + } else { + fmtString += fmt[pct]; + pos = pct + 1; + } + } + + res.push_back( + asmInline("emux_printf", {"\"" + fmtString + "\""})); + res.push_back(asmInline(".set noat", {"# print"})); + res.push_back(asmInline(".set nomacro", {"# print"})); + return res; + }; + + // invert + m["invert_half"] = b_invert_half; + m["invert"] = b_invert; + m["invert_half_sqrt"] = b_invert_half_sqrt; + + addMtc0Write("set_rsp_status", reg::RegCop0::SP_STATUS); + m["print"] = b_print; + m["printf"] = b_printf; + m["min"] = [](const VarDef *vr, + const std::vector &a, + const std::string &s) { return b_minmax(vr, a, s, "<"); }; + m["max"] = [](const VarDef *vr, + const std::vector &a, + const std::string &s) { + return b_minmax(vr, a, s, ">="); + }; + + return m; +} + +static BuiltinMap registry = buildRegistry(); + +const BuiltinFn *lookup(const std::string &name) { + auto it = registry.find(name); + return it != registry.end() ? &it->second : nullptr; +} + +} // namespace rspl::builtins diff --git a/cpp/src/builtins.h b/cpp/src/builtins.h new file mode 100644 index 0000000..05c52ca --- /dev/null +++ b/cpp/src/builtins.h @@ -0,0 +1,21 @@ +#pragma once + +#include "asm.h" +#include "state.h" + +#include +#include +#include + +namespace rspl::builtins { + +using BuiltinFn = + std::function(const VarDef *, // varRes (null if no left side) + const std::vector &, // args + const std::string & // swizzle + )>; + +// Look up a builtin by name. Returns nullptr if not found. +const BuiltinFn *lookup(const std::string &name); + +} // namespace rspl::builtins diff --git a/cpp/src/main.cpp b/cpp/src/main.cpp new file mode 100644 index 0000000..59a83be --- /dev/null +++ b/cpp/src/main.cpp @@ -0,0 +1,174 @@ +/** + * RSPL C++ transpiler — CLI entry point. + */ + +#include "pipeline.h" +#include "preproc.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct CliArgs { + std::string inputFile; + std::string outputFile; + bool astDump = false; + bool optimize = true; + bool reorder = false; + int optimizeTime = 30'000; + int optWorkers = 0; // 0 = auto (hw threads - 1) + bool rspqWrapper = true; + bool help = false; + std::vector defines; // "KEY=VALUE" pairs +}; + +void printHelp() { + std::cout << R"(Usage: rspl [options] + +Options: + -o Output .S file (default: input base + .S) + -D KEY=VALUE Define a preprocessor constant + --opt-time=N Optimizer time budget in seconds (default: 30) + --opt-workers=N Number of reorder worker threads (default: auto) + --no-optimize Disable optimization + --reorder Enable instruction reordering + --no-rspq Disable RSPQ wrapper + --ast-dump Print parsed AST and exit + -h, --help Show this help +)"; +} + +CliArgs parseArgs(int argc, char **argv) { + CliArgs args; + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "-h" || arg == "--help") { args.help = true; } + else if (arg == "--ast-dump") { args.astDump = true; } + else if (arg == "-o" && i + 1 < argc) { args.outputFile = argv[++i]; } + else if (arg == "--no-optimize") { args.optimize = false; } + else if (arg == "--reorder") { args.reorder = true; } + else if (arg == "--no-rspq") { args.rspqWrapper = false; } + else if (arg == "-D" && i + 1 < argc) { args.defines.push_back(argv[++i]); } + else if (arg.starts_with("-D")) { args.defines.push_back(arg.substr(2)); } + else if (arg.starts_with("--opt-time=")) { args.optimizeTime = std::stoi(arg.substr(11)) * 1000; } + else if (arg.starts_with("--opt-workers=")) { args.optWorkers = std::stoi(arg.substr(14)); } + else if (!arg.starts_with("-")) { args.inputFile = arg; } + } + return args; +} + +std::string readFile(const std::string &path) { + std::ifstream f(path); + if (!f) { std::cerr << "Error: cannot open file: " << path << "\n"; std::exit(1); } + std::ostringstream ss; ss << f.rdbuf(); return ss.str(); +} + +void writeFile(const std::string &path, const std::string &content) { + std::ofstream f(path); + if (!f) { std::cerr << "Error: cannot write file: " << path << "\n"; std::exit(1); } + f << content; +} + +std::string execJsParser(const std::string &rsplPath, bool skipPreproc) { + const char *scriptPath = std::getenv("RSPL_PARSE_JS"); + std::string cmd; + if (scriptPath) { + cmd = std::string("node ") + scriptPath; + } else { + cmd = "node scripts/parse.js"; + } + cmd += skipPreproc ? " --preprocessed " : " "; + cmd += rsplPath; + + FILE *pipe = popen(cmd.c_str(), "r"); + if (!pipe) { std::cerr << "Error: cannot start JS parser\n"; std::exit(1); } + std::string result; + char buf[4096]; + while (fgets(buf, sizeof(buf), pipe)) result += buf; + int rc = pclose(pipe); + if (rc != 0) { std::cerr << "Error: JS parser exited with code " << rc << "\n"; std::exit(1); } + return result; +} + +} // namespace + +int main(int argc, char **argv) { + CliArgs args = parseArgs(argc, argv); + + if (args.help) { printHelp(); return 0; } + + // Read source + std::string source; + if (!args.inputFile.empty()) { + source = readFile(args.inputFile); + } else { + std::ostringstream ss; ss << std::cin.rdbuf(); source = ss.str(); + } + if (source.empty()) { std::cerr << "Error: no input provided\n"; return 1; } + + // Parse defines + std::unordered_map defines; + for (const auto &d : args.defines) { + auto eq = d.find('='); + std::string key = d.substr(0, eq); + std::string val = (eq != std::string::npos) ? d.substr(eq + 1) : "1"; + defines[key] = {key, val}; + } + + // Run C++ preprocessor (strip comments + defines + includes) + std::string sourceDir = "."; + if (!args.inputFile.empty()) { + auto slash = args.inputFile.find_last_of('/'); + if (slash != std::string::npos) + sourceDir = args.inputFile.substr(0, slash); + } + std::string preprocessed = rspl::preprocFull(source, defines, sourceDir); + + // Write preprocessed source to temp file for JS parser + std::string tmpPath = "/tmp/rspl_preprocessed.rspl"; + writeFile(tmpPath, preprocessed); + + // Route through JS parser with --preprocessed flag + std::string astJson = execJsParser(tmpPath, true); + + if (args.astDump) { std::cout << astJson; return 0; } + + rspl::TranspileConfig cfg; + cfg.rspqWrapper = args.rspqWrapper; + cfg.optimize = args.optimize; + cfg.reorder = args.reorder; + cfg.optimizeTime = args.optimizeTime; + cfg.optWorkers = args.optWorkers; + cfg.sourceDir = sourceDir; + + auto result = rspl::runPipeline(astJson, cfg); + + // Determine output path: explicit -o flag, or derive from input + std::string outPath = args.outputFile; + if (outPath.empty() && !args.inputFile.empty()) { + outPath = args.inputFile; + // Replace .rspl extension with .S + if (outPath.ends_with(".rspl")) + outPath = outPath.substr(0, outPath.size() - 5) + ".S"; + else + outPath += ".S"; + } + + if (!outPath.empty()) { + writeFile(outPath, result.asm_); + } + + std::cerr << "// DMEM: " << result.sizeDMEM + << " bytes, IMEM: " << result.sizeIMEM << " bytes" << std::endl; + + if (!result.warn.empty()) + std::cerr << result.warn << std::flush; + + return 0; +} diff --git a/cpp/src/operations/branch.cpp b/cpp/src/operations/branch.cpp new file mode 100644 index 0000000..149995d --- /dev/null +++ b/cpp/src/operations/branch.cpp @@ -0,0 +1,147 @@ +#include "branch.h" +#include "scalar.h" + +#include "../asm.h" +#include "../registers.h" +#include "../state.h" +#include "../types.h" + +#include +#include + +namespace rspl::ops { + +static const std::unordered_map BRANCH_INVERT = []() { + std::unordered_map m; + m[getOpcode("beq")] = getOpcode("bne"); + m[getOpcode("bne")] = getOpcode("beq"); + m[getOpcode("bgezal")] = getOpcode("bltzal"); + m[getOpcode("bltzal")] = getOpcode("bgezal"); + m[getOpcode("bgez")] = getOpcode("bltz"); + m[getOpcode("bltz")] = getOpcode("bgez"); + m[getOpcode("blez")] = getOpcode("bgtz"); + m[getOpcode("bgtz")] = getOpcode("blez"); + return m; +}(); + +static const std::unordered_map + ZERO_COMB_BRANCH = []() { + std::unordered_map m; + m["<"] = getOpcode("bltz"); + m["<="] = getOpcode("blez"); + m[">"] = getOpcode("bgtz"); + m[">="] = getOpcode("bgez"); + return m; + }(); + +Opcode invertBranchOp(Opcode op) { + auto it = BRANCH_INVERT.find(op); + if (it == BRANCH_INVERT.end()) { + state.throwError("Cannot invert branch operation: " + + getOpcodeName(op)); + } + return it->second; +} + +std::vector opBranch(const ast::CompareExpr &compare, + const std::string &labelElse, + bool invert) { + bool isImmediate = (compare.right.type == ArgType::Num); + std::string regTestRes; + if (isImmediate) { + regTestRes = reg::Reg::AT; + } else { + const VarDef *var = + state.getRequiredVar(compare.right.value, "compare"); + regTestRes = var->reg; + } + + // Zero-checks can use $zero directly + if (isImmediate && compare.right.value == "0") { + isImmediate = false; + regTestRes = reg::Reg::ZERO; + } + + VarDef varLeft = + state.getRequiredVarCopy(compare.left.value, "left"); + std::string regLeft = varLeft.reg; + + // == and != are simple + if (compare.op == "==" || compare.op == "!=") { + Opcode opBranch = compare.op == "==" ? getOpcode("bne") : getOpcode("beq"); + if (invert) opBranch = invertBranchOp(opBranch); + + std::vector res; + if (isImmediate) { + auto load = loadImmediate(reg::Reg::AT, compare.right.value); + res.insert(res.end(), load.begin(), load.end()); + } + res.push_back( + asmBranch(getOpcodeName(opBranch), {regLeft, regTestRes, labelElse}, labelElse)); + res.push_back(asmNOP()); + return res; + } + + // Zero-combination branch ops + if (!isImmediate && regTestRes == reg::Reg::ZERO) { + auto it = ZERO_COMB_BRANCH.find(compare.op); + if (it != ZERO_COMB_BRANCH.end()) { + Opcode op = it->second; + if (!invert) op = invertBranchOp(op); + return {asmBranch(getOpcodeName(op), {regLeft, labelElse}, labelElse), asmNOP()}; + } + } + + // Transform > and <= into < and >= with swapped args + std::string op = compare.op; + std::string valR = isImmediate ? compare.right.value : regTestRes; + + if (op == ">" || op == "<=") { + if (isImmediate) { + // Increment immediate to handle the "=" part + valR = std::to_string(std::stoll(valR) + 1); + op = op == ">" ? ">=" : "<"; + } else { + op = op == ">" ? "<" : ">="; + std::swap(regLeft, regTestRes); + } + } + + // If immediate doesn't fit in 16-bit signed range, load it into $at first + if (isImmediate) { + int64_t immVal = std::stoll(valR); + if (immVal < -32768 || immVal > 32767) { + isImmediate = false; + regTestRes = reg::Reg::AT; + } + } + + // For = comparisons, use slt + branch + std::string opLessThan = + "slt" + (isImmediate ? std::string("i") : "") + + (isSigned(varLeft.type) ? "" : "u"); + + if (op == "<" || op == ">=") { + Opcode brOp = op == "<" ? getOpcode("beq") : getOpcode("bne"); + if (invert) brOp = invertBranchOp(brOp); + + std::vector res; + if (!isImmediate && regTestRes == reg::Reg::AT) { + auto load = loadImmediate(reg::Reg::AT, valR); + res.insert(res.end(), load.begin(), load.end()); + } + res.push_back( + asmOp(opLessThan, {reg::Reg::AT, regLeft, + isImmediate ? valR : regTestRes})); + res.push_back(asmBranch( + getOpcodeName(brOp), {reg::Reg::AT, reg::Reg::ZERO, labelElse}, labelElse)); + res.push_back(asmNOP()); + return res; + } + + state.throwError( + "Unknown comparison operator: " + compare.op, {}); + return {}; +} + +} // namespace rspl::ops diff --git a/cpp/src/operations/branch.h b/cpp/src/operations/branch.h new file mode 100644 index 0000000..f0510c3 --- /dev/null +++ b/cpp/src/operations/branch.h @@ -0,0 +1,17 @@ +#pragma once + +#include "../asm.h" +#include "../ast.h" + +#include +#include + +namespace rspl::ops { + +Opcode invertBranchOp(Opcode op); + +std::vector opBranch(const ast::CompareExpr &compare, + const std::string &labelElse, + bool invert = false); + +} // namespace rspl::ops diff --git a/cpp/src/operations/scalar.cpp b/cpp/src/operations/scalar.cpp new file mode 100644 index 0000000..ba2e0a2 --- /dev/null +++ b/cpp/src/operations/scalar.cpp @@ -0,0 +1,345 @@ +#include "scalar.h" + +#include "../asm.h" +#include "../registers.h" +#include "../state.h" +#include "../swizzle.h" +#include "../types.h" + +#include +#include +#include + +namespace rspl::ops { + +static void assertScalarVars(const VarDef &varLeft, + const VarDef *varRight = nullptr) { + // Memory labels (empty reg) are always treated as scalar addresses + if ((isVecType(varLeft.type) && !varLeft.reg.starts_with("%lo")) || + (varRight && !varRight->reg.empty() && + !varRight->reg.starts_with("%lo") && + isVecType(varRight->type))) { + state.throwError( + "Scalar-Operation requires all variables to be scalars!"); + } +} + +// Precompute mul->shift table +static int mulToShift(int64_t val) { + for (int i = 0; i < 32; i++) { + if (static_cast(1LL << i) == val) return i; + } + return -1; +} + +std::vector loadImmediate(const std::string ®Dst, + const std::string &valueStr) { + // Labels are always ≤16 bit, use ori + if (!valueStr.empty() && valueStr[0] == '%') { + return {asmOp("ori", {regDst, "$zero", valueStr})}; + } + + int64_t signedVal; + try { + signedVal = std::stoll(valueStr); + } catch (...) { + // Non-numeric strings treated as labels + return {asmOp("ori", {regDst, "$zero", valueStr})}; + } + uint32_t valueU32 = static_cast(signedVal); + + if (valueU32 == 0) { + return {asmOp("or", {regDst, "$zero", "$zero"})}; + } + if (u32InS16Range(valueU32)) { + int16_t valS16 = static_cast(valueU32 & 0xFFFF); + return {asmOp("addiu", {regDst, "$zero", std::to_string(valS16)})}; + } + if (valueU32 <= 0xFFFF) { + return {asmOp("ori", {regDst, "$zero", toHex(valueU32)})}; + } + if ((valueU32 & 0xFFFF) == 0) { + return {asmOp("lui", {regDst, toHex(valueU32 >> 16)})}; + } + return {asmOp("lui", {regDst, toHex(valueU32 >> 16)}), + asmOp("ori", {regDst, regDst, toHex(valueU32 & 0xFFFF)})}; +} + +std::vector opMove(const VarDef &varRes, + const VarDef &varRight) { + if (!varRes.swizzle.empty()) + state.throwError("Swizzling not allowed on scalar variables!"); + + if (!varRight.reg.empty()) { + if (!varRight.swizzle.empty() && !isVecType(varRight.type)) { + state.throwError("Swizzling not allowed for scalar operations!"); + } + if (varRes.reg == varRight.reg) { + state.logWarning("Self-assignment detected, this is a NOP!", {}); + return {}; + } + if (!varRight.swizzle.empty()) { + auto sit = SWIZZLE_MAP.find(varRight.swizzle); + if (sit == SWIZZLE_MAP.end()) + state.throwError("Unknown swizzle: " + varRight.swizzle); + if (varRight.type == TypeClass::Vec16) { + return {asmOp("mfc2", + {varRes.reg, varRight.reg + sit->second})}; + } + return {asmOp("mfc2", {varRes.reg, varRight.reg + sit->second}), + asmOp("andi", {varRes.reg, varRes.reg, "0xFFFF"}), + asmOp("mfc2", {"$at", varRight.reg + sit->second}), + asmOp("sll", {"$at", "$at", "16"}), + asmOp("or", {varRes.reg, varRes.reg, "$at"})}; + } + // Label references should go through loadImmediate for proper op selection + if (!varRight.reg.empty() && varRight.reg[0] == '%') { + return loadImmediate(varRes.reg, varRight.reg); + } + return {asmOp("or", {varRes.reg, "$zero", varRight.reg})}; + } + // varRight is a constant + return loadImmediate(varRes.reg, varRight.reg.empty() + ? std::to_string(static_cast(varRight.value)) + : varRight.reg); +} + +std::vector opLoad(const VarDef &varRes, const VarOrMem &varLoc, + const VarOrMem &varOffset) { + // Map type to load opcode + static const std::unordered_map loadMap = { + {"u8", "lbu"}, {"s8", "lb"}, {"u16", "lhu"}, + {"s16", "lh"}, {"u32", "lw"}, {"s32", "lw"}, + }; + auto it = loadMap.find(toString(varRes.type)); + std::string opName = it != loadMap.end() ? it->second : "lw"; + + // Build the offset string + std::string offsetStr; + if (!varOffset.reg.empty() && varOffset.reg[0] != '%') { + offsetStr = varOffset.reg; + } else if (!varOffset.name.empty() && !varOffset.reg.empty() && varOffset.reg[0] == '%') { + offsetStr = varOffset.reg; // already formatted like %lo(NAME) + } else if (!varOffset.name.empty()) { + offsetStr = "%lo(" + varOffset.name + ")"; + } else if (varOffset.reg.empty()) { + offsetStr = "0"; + } else { + offsetStr = varOffset.reg; + } + + if (!varLoc.reg.empty()) { + return {asmOp(opName, {varRes.reg, offsetStr + "(" + varLoc.reg + ")"})}; + } + // Memory label base — use %lo() with offset + return {asmOp(opName, {varRes.reg, + "%lo(" + varLoc.name + " + " + offsetStr + ")"})}; +} + +std::vector opStore(const VarDef &varRes, + const std::vector &varOffsets) { + if (varOffsets.empty()) + state.throwError("Store needs at least one offset argument!"); + + const auto &varLoc = varOffsets[0]; + static const std::unordered_map storeMap = { + {"u8", "sb"}, {"s8", "sb"}, {"u16", "sh"}, + {"s16", "sh"}, {"u32", "sw"}, {"s32", "sw"}, + }; + auto it = storeMap.find(toString(varRes.type)); + std::string opName = it != storeMap.end() ? it->second : "sw"; + + std::string baseReg = varLoc.reg.empty() ? "$zero" : varLoc.reg; + + // Collect offset parts + std::vector offsetParts; + bool allNumeric = true; + for (size_t i = 1; i < varOffsets.size(); ++i) { + if (!varOffsets[i].reg.empty() && varOffsets[i].reg[0] != '%') { + offsetParts.push_back(varOffsets[i].reg); + allNumeric = allNumeric && varOffsets[i].name.empty(); + } else if (!varOffsets[i].name.empty()) { + offsetParts.push_back(varOffsets[i].name); + allNumeric = false; + } else if (!varOffsets[i].reg.empty()) { + offsetParts.push_back(varOffsets[i].reg); + allNumeric = false; + } + } + // If base is a memory label, push it as an extra offset + if (varLoc.reg.empty() && !varLoc.name.empty()) { + offsetParts.push_back(varLoc.name); + allNumeric = false; + } + + std::string offset; + for (size_t i = 0; i < offsetParts.size(); ++i) { + if (i > 0) offset += " + "; + offset += offsetParts[i]; + } + if (!offset.empty() && !allNumeric) { + offset = "%lo(" + offset + ")"; + } + + return {asmOp(opName, + {varRes.reg, offset + "(" + baseReg + ")"})}; +} + +// Generic reg-or-immediate operation +static std::vector +opRegOrImmediate(const std::string &opReg, const std::string &opImm, + bool (*rangeCheck)(uint32_t), const VarDef &varRes, + const VarDef &varLeft, const VarDef &varRight) { + assertScalarVars(varLeft, &varRight); + if (!varRight.reg.empty()) { + // Label references like %lo(NAME) should use immediate form + if (varRight.reg[0] == '%') { + return {asmOp(opImm, {varRes.reg, varLeft.reg, varRight.reg})}; + } + return {asmOp(opReg, {varRes.reg, varLeft.reg, varRight.reg})}; + } + uint32_t valU32 = varRight.value; + if (rangeCheck(valU32)) { + return {asmOp(opImm, {varRes.reg, varLeft.reg, + std::to_string(valU32 & 0xFFFF)})}; + } + auto load = loadImmediate("$at", std::to_string(valU32)); + load.push_back(asmOp(opReg, {varRes.reg, varLeft.reg, "$at"})); + return load; +} + +std::vector opAdd(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight) { + assertScalarVars(varLeft, &varRight); + return opRegOrImmediate("addu", "addiu", u32InS16Range, varRes, varLeft, + varRight); +} + +std::vector opSub(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight) { + assertScalarVars(varLeft, &varRight); + if (!varRight.reg.empty()) { + if (varRight.reg[0] == '%') { + state.throwError("Subtraction cannot use labels!"); + } + return {asmOp("subu", {varRes.reg, varLeft.reg, varRight.reg})}; + } + VarDef negRight = varRight; + negRight.value = -varRight.value; + return opAdd(varRes, varLeft, negRight); +} + +std::vector opMul(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight, bool /*clearAccum*/) { + assertScalarVars(varLeft, &varRight); + int shift = mulToShift(varRight.value); + if (!varRight.reg.empty() || shift < 0) { + state.throwError( + "Scalar-Multiplication only allowed with a power-of-two constant!"); + } + if (varRight.value == 1) { + state.throwError("Scalar-Multiplication with 1 is a NOP!"); + } + VarDef shiftVar = varRight; + shiftVar.value = shift; + return opShiftLeft(varRes, varLeft, shiftVar); +} + +std::vector opDiv(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight) { + assertScalarVars(varLeft, &varRight); + int shift = mulToShift(varRight.value); + if (!varRight.reg.empty() || shift < 0) { + state.throwError( + "Scalar-Division only allowed with a power-of-two constant!"); + } + if (varRight.value == 1) { + state.throwError("Scalar-Division by 1 is a NOP!"); + } + VarDef shiftVar = varRight; + shiftVar.value = shift; + bool logical = varLeft.type != TypeClass::Unknown && (varLeft.type == TypeClass::U32 || varLeft.type == TypeClass::U16 || varLeft.type == TypeClass::U8); + return opShiftRight(varRes, varLeft, shiftVar, logical); +} + +std::vector opShiftLeft(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight) { + assertScalarVars(varLeft, &varRight); + if (varRight.value < 0 || varRight.value > 31) { + state.throwError("Shift-Left value must be in range 0(varRight.value))})}; +} + +std::vector opShiftRight(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight, bool logical) { + assertScalarVars(varLeft, &varRight); + if (varRight.value < 0 || varRight.value > 31) { + state.throwError("Shift-Right value must be in range 0(varRight.value)) : varRight.reg; + return {asmOp(instr, {varRes.reg, varLeft.reg, valR})}; +} + +std::vector opAnd(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight) { + return opRegOrImmediate("and", "andi", u32InU16Range, varRes, varLeft, + varRight); +} + +std::vector opOr(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight) { + return opRegOrImmediate("or", "ori", u32InU16Range, varRes, varLeft, + varRight); +} + +std::vector opXOR(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight) { + return opRegOrImmediate("xor", "xori", u32InU16Range, varRes, varLeft, + varRight); +} + +std::vector opNOR(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight) { + if (varRight.reg.empty()) + state.throwError("NOR is only supported for variables!"); + return opRegOrImmediate("nor", "nori", u32InU16Range, varRes, varLeft, + varRight); +} + +std::vector opBitFlip(const VarDef &varRes, + const VarDef &varRight) { + if (varRight.reg.empty()) + state.throwError("Bitflip is only supported for variables!"); + return {asmOp("nor", {varRes.reg, "$zero", varRight.reg})}; +} + +std::vector opCompare(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight, + const std::string &op, bool /*ternary*/) { + if (op == "<") { + return {asmOp("slt", {varRes.reg, varLeft.reg, varRight.reg})}; + } + if (op == ">") { + return {asmOp("slt", {varRes.reg, varRight.reg, varLeft.reg})}; + } + state.throwError("Compare op '" + op + "' not implemented yet!"); + return {}; +} + +} // namespace rspl::ops + +// TEMPORARY DEBUG +#include +static int debugStoreCount = 0; diff --git a/cpp/src/operations/scalar.h b/cpp/src/operations/scalar.h new file mode 100644 index 0000000..e5816f9 --- /dev/null +++ b/cpp/src/operations/scalar.h @@ -0,0 +1,61 @@ +#pragma once + +#include "../asm.h" +#include "../state.h" + +#include +#include + +namespace rspl::ops { + +// Load a 32-bit integer into a register with minimal instructions. +std::vector loadImmediate(const std::string ®Dst, + const std::string &value); + +// Move / assign (scalar to scalar, or immediate to scalar) +std::vector opMove(const VarDef &varRes, const VarDef &varRight); + +// Load from memory +std::vector opLoad(const VarDef &varRes, const VarOrMem &varLoc, + const VarOrMem &varOffset); + +// Store to memory +std::vector opStore(const VarDef &varRes, + const std::vector &varOffsets); + +// Arithmetic +std::vector opAdd(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opSub(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opMul(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight, + bool clearAccum = true); +std::vector opDiv(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); + +// Shifts +std::vector opShiftLeft(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opShiftRight(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight, bool logical); + +// Bitwise +std::vector opAnd(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opOr(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opNOR(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opXOR(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opBitFlip(const VarDef &varRes, + const VarDef &varRight); + +// Compare (scalar in if/while conditions) +std::vector opCompare(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight, const std::string &op, + bool ternary); + +} // namespace rspl::ops diff --git a/cpp/src/operations/user_function.cpp b/cpp/src/operations/user_function.cpp new file mode 100644 index 0000000..503488d --- /dev/null +++ b/cpp/src/operations/user_function.cpp @@ -0,0 +1,71 @@ +#include "user_function.h" +#include "scalar.h" + +#include "../asm.h" +#include "../state.h" + +namespace rspl::ops { + +std::vector callUserFunction( + const std::string &name, const std::vector &args) { + const FuncDef *userFunc = state.getFunction(name); + if (!userFunc) { + if (!state.varExists(name)) { + state.throwError("Function " + name + " not known!"); + } + // Indirect call through register variable + static FuncDef indirect; + indirect.name = *state.getVarReg(name); + indirect.isRelative = false; + userFunc = &indirect; + } + + std::vector res; + + if (userFunc->args.size() != args.size()) { + state.throwError("Function " + name + " expects " + + std::to_string(userFunc->args.size()) + + " arguments, got " + + std::to_string(args.size()) + "!"); + } + + for (size_t i = 0; i < args.size(); ++i) { + const auto &argUser = args[i]; + const auto &argDef = userFunc->args[i]; + if (argUser.type == ArgType::Num) { + auto load = + loadImmediate(argDef.reg, argUser.value); + res.insert(res.end(), load.begin(), load.end()); + } else { + const VarDef *argVar = + state.getRequiredVar(argUser.value, "arg" + std::to_string(i)); + if (toString(argVar->type) != toString(argDef.type)) { + state.throwError("Function " + name + + " expects argument " + std::to_string(i) + + " to be of type " + toString(argDef.type) + + ", got " + toString(argVar->type) + "!"); + } + if (argVar->reg != argDef.reg) { + state.throwError("Function " + name + + " expects argument " + std::to_string(i) + + " to be in register " + argDef.reg + + ", got " + argVar->reg + "!"); + } + } + } + + bool isRelative = userFunc->isRelative; + auto annos = state.getAnnotations("Relative"); + if (!annos.empty()) isRelative = true; + + std::vector regsArg; + for (const auto &arg : userFunc->args) { + regsArg.push_back(arg.reg); + } + + res.push_back(asmFunction(userFunc->name, regsArg, isRelative)); + res.push_back(asmNOP()); + return res; +} + +} // namespace rspl::ops diff --git a/cpp/src/operations/user_function.h b/cpp/src/operations/user_function.h new file mode 100644 index 0000000..6fa0730 --- /dev/null +++ b/cpp/src/operations/user_function.h @@ -0,0 +1,14 @@ +#pragma once + +#include "../asm.h" +#include "../ast.h" + +#include +#include + +namespace rspl::ops { + +std::vector callUserFunction(const std::string &name, + const std::vector &args); + +} // namespace rspl::ops diff --git a/cpp/src/operations/vector.cpp b/cpp/src/operations/vector.cpp new file mode 100644 index 0000000..35e5911 --- /dev/null +++ b/cpp/src/operations/vector.cpp @@ -0,0 +1,1090 @@ +#include "vector.h" +#include "scalar.h" + +#include "../asm.h" +#include "../registers.h" +#include "../state.h" +#include "../swizzle.h" +#include "../types.h" + +#include +#include + +namespace rspl::ops { + +// --- Vec32 register helpers (mirror registers.js) --------------------- + +static const std::string *nextVecReg(const std::string ®Name) { + return reg::nextVecReg(regName); +} + +static std::string intReg(const VarDef &v) { return v.reg; } + +static std::string fractReg(const VarDef &v) { + if (v.type == TypeClass::Vec32 || v.originalType == TypeClass::Vec32) { + const auto *next = reg::nextVecReg(v.reg); + return next ? *next : reg::Reg::VZERO; + } + if (v.castType == CastType::Ufract || v.castType == CastType::Sfract) { + return v.reg; + } + return reg::Reg::VZERO; +} + +// Returns the two registers that represent the full vec32 value. +// For sources with a cast (ufract/sfract), the irrelevant half is VZERO. +// For destinations, callers should use getVec32DstRegs() to always get the +// original register pair. +std::pair getVec32Regs(const VarDef &v) { + if (v.type == TypeClass::Vec32) { + const auto *next = reg::nextVecReg(v.reg); + return {v.reg, next ? *next : reg::Reg::VZERO}; + } + if (v.castType == CastType::Ufract || v.castType == CastType::Sfract) { + return {reg::Reg::VZERO, v.reg}; + } + return {v.reg, reg::Reg::VZERO}; +} + +// JS getVec32RegsResLR: when the result is not a two-reg type (cast or +// vec16), the meaningful source register is used for both halves of the +// source pair so vmrg writes land correctly. +static void adjustRegsForResType(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight, + std::pair ®sL, + std::pair ®sR) { + if (!isTwoRegType(varRes.type)) { + regsL.first = varLeft.reg; + regsR.first = varRight.reg; + } +} + +// Returns the actual register pair of the original vec32, regardless of cast. +static std::pair +getVec32DstRegs(const VarDef &v) { + if (v.type == TypeClass::Vec32) { + const auto *next = reg::nextVecReg(v.reg); + return {v.reg, next ? *next : reg::Reg::VZERO}; + } + if (v.originalType == TypeClass::Vec32) { + if (v.castType == CastType::Ufract || v.castType == CastType::Sfract) { + // v.reg is the fract register; the int reg is the previous one + const auto *prev = reg::nextReg(v.reg, -1); + std::string intReg = prev ? *prev : reg::Reg::VZERO; + return {intReg, v.reg}; + } + // sint/uint: v.reg is the int register + const auto *next = reg::nextVecReg(v.reg); + return {v.reg, next ? *next : reg::Reg::VZERO}; + } + // Shouldn't be reached for vec32 destinations + return {v.reg, reg::Reg::VZERO}; +} + +static void assertVectorVars(const VarDef &varLeft, + const VarDef *varRight = nullptr) { + if (!isVecType(varLeft.type) || + (varRight && !varRight->reg.empty() && + !isVecType(varRight->type))) { + state.throwError( + "Vector-Operation requires all variables to be vectors!"); + } +} + +// --- Generic logic op (vand, vor, vxor, vnor) ------------------------- + +static std::vector +genericLogicOp(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight, const std::string &op) { + std::string funcName; + for (char c : op) + funcName += static_cast(std::toupper(c)); + if (varRight.reg.empty()) + state.throwError(funcName + " cannot be done with a constant!"); + if (!varRes.swizzle.empty() || !varLeft.swizzle.empty()) + state.throwError(funcName + + " only allows swizzle on the right side!"); + assertVectorVars(varLeft, &varRight); + + auto sit = SWIZZLE_MAP.find(varRight.swizzle); + if (sit == SWIZZLE_MAP.end()) + state.throwError("Unsupported swizzle: " + varRight.swizzle); + + bool is32 = (varRes.type == TypeClass::Vec32); + std::string swSuffix = sit->second; + std::string regR = varRight.reg + swSuffix; + + std::vector res; + res.push_back(asmOp(op, {varRes.reg, varLeft.reg, regR})); + if (is32) { + res.push_back(asmOp(op, {*reg::nextVecReg(varRes.reg), + fractReg(varLeft), fractReg(varRight) + + swSuffix})); + } + return res; +} + +// --- opMoveVec -------------------------------------------------------- + +std::vector opMoveVec(const VarDef &varRes, + const VarDef &varRight) { + bool isVec32 = (varRes.type == TypeClass::Vec32 || + varRes.originalType == TypeClass::Vec32); + + // Constant assignment to full vector + if (varRight.reg.empty() && varRes.swizzle.empty()) { + auto pIt = POW2_SWIZZLE_VAR.find(varRight.value); + if (pIt == POW2_SWIZZLE_VAR.end()) { + state.throwError( + "Can only assign a constant to a vector if it is a power of " + "two or zero!"); + } + auto sit = SWIZZLE_MAP.find(pIt->second.swizzle); + std::string regPow = pIt->second.reg + sit->second; + std::vector res; + auto regsDst = getVec32Regs(varRes); + bool hasCast = + varRes.castType != CastType::None || !varRes.swizzle.empty(); + if (hasCast) { + // For casts, only the relevant half gets the constant + if (regsDst.second != reg::Reg::VZERO) + res.push_back( + asmOp("vxor", {regsDst.second, reg::Reg::VZERO, regPow})); + } else { + res.push_back( + asmOp("vxor", {regsDst.first, reg::Reg::VZERO, regPow})); + if (isVec32 && regsDst.second != reg::Reg::VZERO) + res.push_back( + asmOp("vxor", {regsDst.second, reg::Reg::VZERO, + reg::Reg::VZERO})); + } + return res; + } + + // Scalar -> vector + bool isScalar = !varRight.reg.empty() && !isVecType(varRight.type) && + varRight.reg[0] != '%'; + if (isScalar) { + auto swizzleRes = SWIZZLE_MAP.find(varRes.swizzle); + std::string sRes = + swizzleRes != SWIZZLE_MAP.end() ? swizzleRes->second : ""; + bool needsExpand = varRes.swizzle.empty(); + if (needsExpand) + sRes = ".e0"; + + std::vector scalarRes; + if (isVec32 || varRes.originalType == TypeClass::Vec32) { + auto regsDst = getVec32DstRegs(varRes); + scalarRes.push_back( + asmOp("mtc2", {varRight.reg, regsDst.second + sRes})); + scalarRes.push_back( + asmOp("srl", {reg::Reg::AT, varRight.reg, "16"})); + scalarRes.push_back( + asmOp("mtc2", {reg::Reg::AT, regsDst.first + sRes})); + if (needsExpand) { + scalarRes.push_back( + asmOp("vor", {regsDst.first, reg::Reg::VZERO, + regsDst.first + sRes})); + scalarRes.push_back( + asmOp("vor", {regsDst.second, reg::Reg::VZERO, + regsDst.second + sRes})); + } + } else { + scalarRes.push_back( + asmOp("mtc2", {varRight.reg, varRes.reg + sRes})); + if (needsExpand) { + scalarRes.push_back( + asmOp("vor", {varRes.reg, reg::Reg::VZERO, + varRes.reg + sRes})); + } + } + return scalarRes; + } + + // Vector -> vector + if (!varRight.reg.empty()) { + auto regsDst = getVec32DstRegs(varRes); + auto regsR = getVec32Regs(varRight); + + // When both sides have casts, only use the base register of each + // (the paired register is VZERO, whose writes are filtered as NOPs). + if (varRes.castType != CastType::None && varRight.castType != CastType::None) { + regsDst = {varRes.reg, reg::Reg::VZERO}; + regsR = {varRight.reg, reg::Reg::VZERO}; + } + + // Full vector move + if (varRight.swizzle.empty()) { + std::vector moveRes; + moveRes.push_back( + asmOp("vor", {regsDst.first, reg::Reg::VZERO, regsR.first})); + if (isVec32) { + moveRes.push_back( + asmOp("vor", {regsDst.second, reg::Reg::VZERO, regsR.second})); + } + return moveRes; + } + + // Broadcast swizzle into full vector + if (varRes.swizzle.empty()) { + auto sit = SWIZZLE_MAP.find(varRight.swizzle); + std::string suffix = + sit != SWIZZLE_MAP.end() ? sit->second : ""; + std::vector broadcastRes; + broadcastRes.push_back( + asmOp("vor", + {regsDst.first, reg::Reg::VZERO, regsR.first + suffix})); + if (isVec32) { + broadcastRes.push_back(asmOp("vor", {regsDst.second, reg::Reg::VZERO, + regsR.second + suffix})); + } + return broadcastRes; + } + + // Half-vector move: 4-lane swizzle on both sides requires + // storing through scratch memory since vmov only moves single lanes + // and sdv/ldv can move 8 bytes (4 lanes) at once. + // Examples: res.xyzw = a.XYZW (upper→lower), + // res.XYZW = a.xyzw (lower→upper) + bool isHalfMove = + !varRes.swizzle.empty() && varRes.swizzle.size() == 4 && + !varRight.swizzle.empty() && varRight.swizzle.size() == 4; + if (isHalfMove) { + auto sitSrc = SWIZZLE_SCALAR_IDX.find(varRight.swizzle[0]); + auto sitDst = SWIZZLE_SCALAR_IDX.find(varRes.swizzle[0]); + if (sitSrc != SWIZZLE_SCALAR_IDX.end() && + sitDst != SWIZZLE_SCALAR_IDX.end()) { + int srcOffset = sitSrc->second * 2; // byte offset into source reg + int dstOffset = sitDst->second * 2; // byte offset into dest reg + int accessLen = 8; // 4 lanes × 2 bytes = 8 bytes + + state.addAnnotation("Barrier", "__SCRATCH_MEM__"); + + std::vector halfRes; + halfRes.push_back( + asmOp("ori", {reg::Reg::AT, reg::Reg::ZERO, + "%lo(RSPQ_SCRATCH_MEM)"})); + halfRes.push_back( + asmOp("sdv", {regsR.first, std::to_string(srcOffset), "0", + reg::Reg::AT})); + if (regsR.second != reg::Reg::VZERO) + halfRes.push_back( + asmOp("sdv", {regsR.second, std::to_string(srcOffset), + std::to_string(accessLen), reg::Reg::AT})); + halfRes.push_back( + asmOp("ldv", {regsDst.first, std::to_string(dstOffset), "0", + reg::Reg::AT})); + if (regsDst.second != reg::Reg::VZERO) + halfRes.push_back( + asmOp("ldv", {regsDst.second, std::to_string(dstOffset), + std::to_string(accessLen), reg::Reg::AT})); + return halfRes; + } + } + + // Single lane move + auto sitRes = SWIZZLE_MAP.find(varRes.swizzle); + auto sitRight = SWIZZLE_MAP.find(varRight.swizzle); + std::string sRes = + sitRes != SWIZZLE_MAP.end() ? sitRes->second : ""; + std::string sRight = + sitRight != SWIZZLE_MAP.end() ? sitRight->second : ""; + std::vector laneRes; + laneRes.push_back( + asmOp("vmov", {regsDst.first + sRes, regsR.first + sRight})); + if (isVec32) { + laneRes.push_back(asmOp("vmov", + {regsDst.second + sRes, regsR.second + sRight})); + } + return laneRes; + } + + // Constant to single lane or scaled assignment + if (varRight.reg.empty()) { + auto sit = SWIZZLE_MAP.find(varRes.swizzle); + std::string sRes = sit != SWIZZLE_MAP.end() ? sit->second : ""; + int64_t val = varRight.value; + + // Power-of-two lookup (works for both full-vector and single-lane) + auto pIt = POW2_SWIZZLE_VAR.find(val); + if (pIt != POW2_SWIZZLE_VAR.end()) { + auto swSit = SWIZZLE_MAP.find(pIt->second.swizzle); + std::string swSuffix = + swSit != SWIZZLE_MAP.end() ? swSit->second : ""; + if (varRes.swizzle.empty()) { + // Full vector constant — use vxor with power-of-two reg + return { + asmOp("vxor", {varRes.reg, reg::Reg::VZERO, + pIt->second.reg + swSuffix})}; + } + // Single lane — use vmov from power-of-two reg + std::vector pow2Res; + auto regsDst = getVec32Regs(varRes); + pow2Res.push_back( + asmOp("vmov", {regsDst.first + sRes, pIt->second.reg + swSuffix})); + if (isVec32) { + auto zeroIt = POW2_SWIZZLE_VAR.find(0); + auto zSwSit = + zeroIt != POW2_SWIZZLE_VAR.end() + ? SWIZZLE_MAP.find(zeroIt->second.swizzle) + : SWIZZLE_MAP.end(); + std::string zSuffix = + zSwSit != SWIZZLE_MAP.end() ? zSwSit->second : ""; + pow2Res.push_back( + asmOp("vmov", {regsDst.second + sRes, + zeroIt->second.reg + zSuffix})); + } + return pow2Res; + } + + // Full vector constant (non-power-of-two) — try vxor for zero + if (varRes.swizzle.empty() && val == 0) { + auto zeroIt = POW2_SWIZZLE_VAR.find(0); + auto zSit = SWIZZLE_MAP.find(zeroIt->second.swizzle); + return {asmOp("vxor", {varRes.reg, reg::Reg::VZERO, + zeroIt->second.reg + zSit->second})}; + } + + // Float constant — convert to FP32->FP16 fixed-point parts + double dVal = varRight.value; + if (dVal != static_cast(dVal)) { + bool isFractCast = + varRes.castType == CastType::Ufract || varRes.castType == CastType::Sfract; + double scale = (varRes.castType == CastType::Sfract) ? 0.5 : 1.0; + auto valueFP32 = + static_cast(dVal * scale * 65536.0); + int64_t valInt = (valueFP32 >> 16) & 0xFFFF; + int64_t valFract = valueFP32 & 0xFFFF; + if (varRes.castType == CastType::Sfract && dVal >= 0) { + valFract = std::min(valFract, int64_t(0x7FFF)); + } + if (isFractCast) valInt = valFract; + + std::vector fload; + auto regsDst = getVec32Regs(varRes); + if (valInt != 0) { + auto li = loadImmediate(reg::Reg::AT, std::to_string(valInt)); + fload.insert(fload.end(), li.begin(), li.end()); + } + fload.push_back( + asmOp("mtc2", {valInt == 0 ? reg::Reg::ZERO : reg::Reg::AT, + regsDst.first + sRes})); + if (isVec32 || varRes.originalType == TypeClass::Vec32) { + if (valFract != 0) { + auto li = + loadImmediate(reg::Reg::AT, std::to_string(valFract)); + fload.insert(fload.end(), li.begin(), li.end()); + } + fload.push_back( + asmOp("mtc2", {valFract == 0 ? reg::Reg::ZERO : reg::Reg::AT, + regsDst.second + sRes})); + } else if (isFractCast) { + // For fract cast on vec16, only the fract value is used (already in + // valInt) + } + return fload; + } + + // Load immediate into $at and use mtc2 + auto load = + ops::loadImmediate(reg::Reg::AT, std::to_string(val)); + load.push_back( + asmOp("mtc2", {reg::Reg::AT, varRes.reg + sRes})); + // Expand to full vector if no swizzle + if (varRes.swizzle.empty()) { + load.push_back( + asmOp("vor", {varRes.reg, reg::Reg::VZERO, + varRes.reg + sRes})); + } + return load; + } + + state.throwError("Unhandled vector move case"); + return {}; +} + +// --- opLoadVec -------------------------------------------------------- + +std::vector opLoadVec(const VarDef &varRes, + const VarOrMem &varLoc, + const VarOrMem &varOffset, + const std::string &swizzle, + bool isPackedByte, bool isSigned, + bool isUnaligned) { + std::vector res; + VarOrMem loc = varLoc; + VarOrMem offs = varOffset; + + // Resolve label address into $at + if (loc.reg.empty() && !loc.name.empty()) { + auto load = loadImmediate(reg::Reg::AT, "%lo(" + loc.name + ")"); + res.insert(res.end(), load.begin(), load.end()); + loc.reg = reg::Reg::AT; + } + + // Compute destination offset from result swizzle + int destOffset = 0; + if (!varRes.swizzle.empty()) { + auto sit = SWIZZLE_SCALAR_IDX.find(varRes.swizzle[0]); + if (sit != SWIZZLE_SCALAR_IDX.end()) destOffset = sit->second * 2; + } + + bool is32 = (varRes.type == TypeClass::Vec32); + + // Detect dupeLoad and normalize swizzle (matches JS behavior) + std::string swiz = swizzle; + bool dupeLoad = (swiz == "xyzwxyzw"); + if (dupeLoad) swiz = "xyzw"; + + int accessLen = swiz.empty() ? 16 : static_cast(swiz.size()) * 2; + + // Select load instruction based on access length + static const std::unordered_map loadInstrMap = { + {2, "lsv"}, {4, "llv"}, {8, "ldv"}, {16, "lqv"}, + }; + auto lit = loadInstrMap.find(accessLen); + if (lit == loadInstrMap.end()) { + state.throwError("Invalid load access length"); + return {}; + } + std::string loadInstr = lit->second; + + // Packed byte overrides instruction + (void)isSigned; + if (isPackedByte) { + if (is32) state.throwError("Packed byte loads are not supported for 32-bit vectors!"); + loadInstr = isSigned ? "lpv" : "luv"; + destOffset /= 2; + } + + // Compute source offset from swizzle + numeric offset + int srcOffset = 0; + if (!swiz.empty()) { + auto ssi = SWIZZLE_SCALAR_IDX.find(swiz[0]); + if (ssi != SWIZZLE_SCALAR_IDX.end()) srcOffset = ssi->second * 2; + } + if (!offs.reg.empty()) srcOffset += static_cast(std::stoll(offs.reg)); + + // Alignment check for full vector loads + if (loadInstr == "lqv" && (srcOffset % 16) != 0) { + state.throwError("Invalid full vector-load offset, must be a multiple of 16, " + std::to_string(srcOffset) + " given"); + } + + std::string alignOp; + if (isUnaligned && loadInstr == "lqv") alignOp = "lrv"; + + // Emit load instruction(s) + auto emit = [&](const std::string ®, int dOff, int sOff) { + res.push_back(asmOp(loadInstr, {reg, std::to_string(dOff), std::to_string(sOff), loc.reg})); + }; + + emit(varRes.reg, destOffset, srcOffset); + if (dupeLoad) emit(varRes.reg, destOffset + 8, srcOffset); + + if (!alignOp.empty()) { + auto emitA = [&](const std::string ®, int dOff, int sOff) { + res.push_back(asmOp(alignOp, {reg, std::to_string(dOff), std::to_string(sOff), loc.reg})); + }; + emitA(varRes.reg, destOffset, srcOffset + 0x10); + if (dupeLoad) emitA(varRes.reg, destOffset + 8, srcOffset + 0x10); + } + + if (is32) { + const auto *nextRegV = reg::nextVecReg(varRes.reg); + if (nextRegV) { + emit(*nextRegV, destOffset, srcOffset + accessLen); + if (dupeLoad) emit(*nextRegV, destOffset + 8, srcOffset + accessLen); + if (!alignOp.empty()) { + auto emitA = [&](const std::string ®, int dOff, int sOff) { + res.push_back(asmOp(alignOp, {reg, std::to_string(dOff), std::to_string(sOff), loc.reg})); + }; + emitA(*nextRegV, destOffset, srcOffset + accessLen + 0x10); + if (dupeLoad) emitA(*nextRegV, destOffset + 8, srcOffset + accessLen + 0x10); + } + } + } + + return res; +} + +std::vector opLoadBytes(const VarDef &varRes, + const VarOrMem &varLoc, + const VarOrMem &varOffset, + const std::string &swizzle, + bool isSigned) { + return opLoadVec(varRes, varLoc, varOffset, swizzle, true, isSigned, + false); +} + +// --- opStoreVec ------------------------------------------------------- + +std::vector opStoreVec(const VarDef &varRes, + const std::vector &varOffsets, + bool isPackedByte, bool isSigned, + bool isUnaligned) { + if (varOffsets.empty()) + state.throwError("Vector stores need at least one offset!"); + const auto &varLoc = varOffsets[0]; + + bool is32 = (varRes.type == TypeClass::Vec32); + int accessLen = varRes.swizzle.empty() + ? 16 + : static_cast(varRes.swizzle.size()) * 2; + + std::string storeInstr; + switch (accessLen) { + case 2: storeInstr = "ssv"; break; + case 4: storeInstr = "slv"; break; + case 8: storeInstr = "sdv"; break; + case 16: storeInstr = "sqv"; break; + default: state.throwError("Invalid store access length"); return {}; + } + + int srcOffset = varRes.swizzle.empty() + ? 0 + : (SWIZZLE_SCALAR_IDX.count(varRes.swizzle[0]) + ? SWIZZLE_SCALAR_IDX.at(varRes.swizzle[0]) * 2 + : 0); + + if (isPackedByte) { + if (is32) state.throwError("Packed byte stores are not supported for 32-bit vectors!"); + if (!varRes.swizzle.empty() && varRes.swizzle.size() != 1) + state.throwError("Packed byte stores only support single-lane swizzles!"); + storeInstr = isSigned ? "spv" : "suv"; + srcOffset /= 2; + } + + std::string alignOp; + if (isUnaligned && storeInstr == "sqv") alignOp = "srv"; + + std::vector res; + + std::string baseReg = varLoc.reg; + if (varLoc.reg.empty() && !varLoc.name.empty()) { + auto load = loadImmediate(reg::Reg::AT, "%lo(" + varLoc.name + ")"); + res.insert(res.end(), load.begin(), load.end()); + baseReg = reg::Reg::AT; + } + + // Sum all offset arguments (matching JS opStore.js:280-284) + int baseOffset = 0; + for (size_t i = 1; i < varOffsets.size(); ++i) { + if (!varOffsets[i].reg.empty()) + baseOffset += std::stoi(varOffsets[i].reg); + } + + auto emit = [&](const std::string ®, int sOff, int bOff) { + res.push_back(asmOp(storeInstr, {reg, std::to_string(sOff), std::to_string(bOff), baseReg})); + }; + auto emitA = [&](const std::string ®, int sOff, int bOff) { + res.push_back(asmOp(alignOp, {reg, std::to_string(sOff), std::to_string(bOff), baseReg})); + }; + + emit(varRes.reg, srcOffset, baseOffset); + if (!alignOp.empty()) emitA(varRes.reg, srcOffset, baseOffset + 0x10); + + if (is32) { + const auto *nextRegV = reg::nextVecReg(varRes.reg); + if (nextRegV) { + emit(*nextRegV, srcOffset, baseOffset + accessLen); + if (!alignOp.empty()) emitA(*nextRegV, srcOffset, baseOffset + accessLen + 0x10); + } + } + + return res; +} + +std::vector opStoreBytes(const VarDef &varRes, + const std::vector &varOffsets, + bool isSigned) { + return opStoreVec(varRes, varOffsets, true, isSigned, false); +} + +// --- Arithmetic ------------------------------------------------------- + +std::vector opAddVec(const VarDef &varRes, + const VarDef &varLeft, + VarDef varRight) { + if (varRight.reg.empty()) { + auto pIt = POW2_SWIZZLE_VAR.find(varRight.value); + if (pIt == POW2_SWIZZLE_VAR.end()) { + state.throwError("Addition by a constant can only be done with " + "powers of two!"); + } + varRight.reg = pIt->second.reg; + varRight.swizzle = pIt->second.swizzle; + } + if (!varRes.swizzle.empty() || !varLeft.swizzle.empty()) { + state.throwError("Addition only allows swizzle on the right side!"); + } + assertVectorVars(varLeft, &varRight); + + auto sit = SWIZZLE_MAP.find(varRight.swizzle); + if (sit == SWIZZLE_MAP.end()) + state.throwError("Unsupported swizzle: " + varRight.swizzle); + + auto regsDst = getVec32Regs(varRes); + auto regsL = getVec32Regs(varLeft); + auto regsR = getVec32Regs(varRight); + + std::string fractOp = "vaddc"; + std::string intOp = "vaddc"; + if (varRes.type == TypeClass::Vec32) { + fractOp = "vaddc"; + intOp = "vadd"; + } else if (varRes.castType != CastType::None) { + if (varRes.castType == CastType::Sfract) fractOp = "vadd"; + if (varRes.castType == CastType::Sint) intOp = "vadd"; + } + + return {asmOp(fractOp, {regsDst.second, regsL.second, + regsR.second + sit->second}), + asmOp(intOp, {regsDst.first, regsL.first, + regsR.first + sit->second})}; +} + +std::vector opSubVec(const VarDef &varRes, + const VarDef &varLeft, + VarDef varRight) { + if (varRight.reg.empty()) { + auto pIt = POW2_SWIZZLE_VAR.find(varRight.value); + if (pIt == POW2_SWIZZLE_VAR.end()) { + state.throwError( + "Subtraction by a constant can only be done with powers of " + "two!"); + } + varRight.reg = pIt->second.reg; + varRight.swizzle = pIt->second.swizzle; + } + assertVectorVars(varLeft, &varRight); + auto sit = SWIZZLE_MAP.find(varRight.swizzle); + if (sit == SWIZZLE_MAP.end()) + state.throwError("Unsupported swizzle: " + varRight.swizzle); + + if (varRes.type == TypeClass::Vec32) { + return {asmOp("vsubc", {*reg::nextReg(varRes.reg), + fractReg(varLeft), + fractReg(varRight) + sit->second}), + asmOp("vsub", {varRes.reg, varLeft.reg, + varRight.reg + sit->second})}; + } + bool isSigned = + varRes.castType != CastType::None && (varRes.castType == CastType::Sfract || varRes.castType == CastType::Sint); + return {asmOp(isSigned ? "vsub" : "vsubc", + {varRes.reg, varLeft.reg, + varRight.reg + sit->second})}; +} + +std::vector opMulVec(const VarDef &varRes, + const VarDef &varLeft, + VarDef varRight, bool clearAccum) { + if (varRight.reg.empty()) { + auto pIt = POW2_SWIZZLE_VAR.find(varRight.value); + if (pIt == POW2_SWIZZLE_VAR.end()) { + state.throwError("Multiplication by a constant can only be done " + "with powers of two!"); + } + varRight.reg = pIt->second.reg; + varRight.swizzle = pIt->second.swizzle; + varRight.type = TypeClass::Vec16; + } + assertVectorVars(varLeft, &varRight); + auto sit = SWIZZLE_MAP.find(varRight.swizzle); + if (sit == SWIZZLE_MAP.end()) + state.throwError("Unsupported swizzle: " + varRight.swizzle); + + std::string swSuffix = sit->second; + bool right32Bit = (varRight.type == TypeClass::Vec32); + std::string fractOp = clearAccum ? "vmudl" : "vmadl"; + std::string intOp = clearAccum ? "vmudn" : "vmadn"; + + // 16-bit multiply with cast + // JS opMul:603-608 — vec16 result special case for sfract/ufract. + // Unlike C++ this does NOT contain caseRef/sint/default paths; + // non-matching cases fall through to the general multiply paths below. + if (varRes.type == TypeClass::Vec16 && + varLeft.type == TypeClass::Vec16 && + (varLeft.castType == CastType::Sfract || varLeft.castType == CastType::Ufract) && + varRight.originalType == TypeClass::Vec32 && + (varRight.castType == CastType::Sfract || varRight.castType == CastType::Ufract)) { + std::string opMid = clearAccum ? "vmudm" : "vmadm"; + return {asmOp(opMid, {varRes.reg, varLeft.reg, + varRight.reg + swSuffix})}; + } + + // vec32 sfract result from vec32 * vec32 (JS opMul:612-619) + if (varRes.originalType == TypeClass::Vec32 && varRes.castType == CastType::Sfract && + varLeft.type == TypeClass::Vec32 && varRight.type == TypeClass::Vec32) { + return { + asmOp("vmudl", {reg::Reg::VTEMP0, fractReg(varLeft), + fractReg(varRight) + swSuffix}), + asmOp("vmadm", {reg::Reg::VTEMP0, intReg(varLeft), + fractReg(varRight) + swSuffix}), + asmOp("vmadn", {varRes.reg, fractReg(varLeft), + intReg(varRight) + swSuffix}), + }; + } + + // 16bit * 32bit multiply (JS opMul:643-649) + if (right32Bit && varRes.type == TypeClass::Vec32 && varLeft.type == TypeClass::Vec16 && + !(varLeft.castType == CastType::Sfract || varLeft.castType == CastType::Ufract)) { + auto regsDst = getVec32Regs(varRes); + return { + asmOp("vmudm", {regsDst.second, varLeft.reg, + fractReg(varRight) + swSuffix}), + asmOp("vmadh", {intReg(varRes), varLeft.reg, + intReg(varRight) + swSuffix}), + asmOp("vmadn", {regsDst.second, reg::Reg::VZERO, reg::Reg::VZERO}), + }; + } + + // Full 32-bit multiplication + if (right32Bit) { + std::vector res; + res.push_back( + asmOp(fractOp, {reg::Reg::VTEMP0, fractReg(varLeft), + fractReg(varRight) + swSuffix})); + res.push_back( + asmOp("vmadm", {reg::Reg::VTEMP0, intReg(varLeft), + fractReg(varRight) + swSuffix})); + intOp = "vmadn"; + auto regsDst = getVec32Regs(varRes); + std::string regResFract = + (regsDst.second == reg::Reg::VZERO) ? reg::Reg::VTEMP0 + : regsDst.second; + res.push_back( + asmOp(intOp, {regResFract, fractReg(varLeft), + intReg(varRight) + swSuffix})); + res.push_back( + asmOp("vmadh", {intReg(varRes), intReg(varLeft), + intReg(varRight) + swSuffix})); + return res; + } + + // Partial multiplication: s16.16 * 0.16 (fractional part of original s16.16) + // JS opMul:662-668 + bool rightSideIsFraction = + (varRight.castType == CastType::Sfract || varRight.castType == CastType::Ufract); + if (rightSideIsFraction && + (varRight.originalType == TypeClass::Vec32 || varRes.type == TypeClass::Vec32)) { + const std::string *nextReg = reg::nextVecReg(varRes.reg); + return { + asmOp(fractOp, + {*nextReg, fractReg(varLeft), varRight.reg + swSuffix}), + asmOp("vmadm", {varRes.reg, varLeft.reg, varRight.reg + swSuffix}), + asmOp("vmadn", {*nextReg, reg::Reg::VZERO, reg::Reg::VZERO}), + }; + } + + // JS opMul:671-686 — vec16 result default path + if (varRes.type == TypeClass::Vec16) { + CastType caseRef = + varLeft.castType != CastType::None ? varLeft.castType : + varRight.castType != CastType::None ? varRight.castType : + varRes.castType; + if (caseRef == CastType::Ufract || caseRef == CastType::Sfract) { + std::string op = clearAccum ? "vmul" : "vmac"; + op += (caseRef == CastType::Ufract) ? "u" : "f"; + return {asmOp(op, {varRes.reg, varLeft.reg, + varRight.reg + swSuffix})}; + } + if (varLeft.castType == CastType::Sint || varRight.castType == CastType::Sint) { + intOp = clearAccum ? "vmudh" : "vmadh"; + } + return {asmOp(intOp, {varRes.reg, varLeft.reg, + varRight.reg + swSuffix})}; + } + + // JS opMul:688-692 — general vec32 path + if (varRes.type == TypeClass::Vec32 || varRes.originalType == TypeClass::Vec32) { + auto regsDst = getVec32Regs(varRes); + std::string regResFract = + (regsDst.second == reg::Reg::VZERO) ? reg::Reg::VTEMP0 + : regsDst.second; + return { + asmOp(intOp, {regResFract, fractReg(varLeft), + intReg(varRight) + swSuffix}), + asmOp("vmadh", {intReg(varRes), intReg(varLeft), + intReg(varRight) + swSuffix})}; + } + + // 16-bit multiply (default — scalar or unhandled) + return {asmOp(intOp, {varRes.reg, varLeft.reg, + varRight.reg + swSuffix})}; +} + +// --- Shifts ----------------------------------------------------------- + +std::vector opShiftLeftVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight) { + if (varRight.reg.empty()) { + } else { + state.throwError( + "Vector-Shift amount must be a constant!"); + } + int64_t shiftPow = 1LL << int64_t(varRight.value); + auto pIt = POW2_SWIZZLE_VAR.find(shiftPow); + if (pIt == POW2_SWIZZLE_VAR.end()) + state.throwError("Invalid shift value"); + + auto sit = SWIZZLE_MAP.find(pIt->second.swizzle); + std::string regR = pIt->second.reg + sit->second; + + // Vec32 shift + if (varRes.type == TypeClass::Vec32) { + auto regsDst = getVec32Regs(varRes); + auto regsL = getVec32Regs(varLeft); + std::string firstReg = + (regsDst.first == regsL.first) ? reg::Reg::VTEMP0 : regsDst.first; + return { + asmOp("vmudl", {firstReg, regsL.second, regR}), + asmOp("vmadn", {regsDst.first, regsL.first, regR}), + asmOp("vmudn", {regsDst.second, regsL.second, regR})}; + } + + // Vec16 result from vec32 source + if (varRes.type == TypeClass::Vec16 && varLeft.type == TypeClass::Vec32) { + auto regsL = getVec32Regs(varLeft); + return { + asmOp("vmudl", {varRes.reg, regsL.second, regR}), + asmOp("vmadn", {varRes.reg, regsL.first, regR})}; + } + + // Vec16 shift + return { + asmOp("vmudn", {varRes.reg, varLeft.reg, regR})}; +} + +std::vector opShiftRightVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight, + bool logical) { + if (varRight.reg.empty()) { + } else { + state.throwError("Vector-Shift amount must be a constant!"); + } + int64_t shiftVal = + static_cast((1.0 / (1LL << int64_t(varRight.value))) * 0x10000); + auto pIt = POW2_SWIZZLE_VAR.find(shiftVal); + if (pIt == POW2_SWIZZLE_VAR.end()) + state.throwError("Invalid shift value"); + + auto sit = SWIZZLE_MAP.find(pIt->second.swizzle); + std::string regR = pIt->second.reg + sit->second; + + // Vec32 shift + if (varRes.type == TypeClass::Vec32) { + auto regsDst = getVec32Regs(varRes); + auto regsL = getVec32Regs(varLeft); + + if (regsL.second == reg::Reg::VZERO) { + std::string instMid = logical ? "vmudn" : "vmudm"; + return { + asmOp(instMid, {regsDst.first, regsL.first, regR}), + asmOp("vmadn", {regsDst.second, reg::Reg::VZERO, + reg::Reg::VZERO})}; + } + + std::string instMid = logical ? "vmadn" : "vmadm"; + return { + asmOp("vmudl", {regsDst.second, regsL.second, regR}), + asmOp(instMid, {regsDst.first, regsL.first, regR}), + asmOp("vmadn", {regsDst.second, reg::Reg::VZERO, + reg::Reg::VZERO})}; + } + + // Vec16 shift + std::string instr = logical ? "vmudl" : "vmudm"; + return {asmOp(instr, {varRes.reg, varLeft.reg, regR})}; +} + +// --- Bitwise ---------------------------------------------------------- + +std::vector opAndVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight) { + return genericLogicOp(varRes, varLeft, varRight, "vand"); +} +std::vector opOrVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight) { + return genericLogicOp(varRes, varLeft, varRight, "vor"); +} +std::vector opNORVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight) { + return genericLogicOp(varRes, varLeft, varRight, "vnor"); +} +std::vector opXORVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight) { + return genericLogicOp(varRes, varLeft, varRight, "vxor"); +} + +std::vector opBitFlipVec(const VarDef &varRes, + const VarDef &varRight) { + if (!varRight.swizzle.empty()) + state.throwError( + "NOT operator is only supported for variables!"); + VarDef zero = varRes; + zero.reg = reg::Reg::VZERO; + zero.type = TypeClass::Vec16; + return genericLogicOp(varRes, varRight, zero, "vnor"); +} + +// --- Special ---------------------------------------------------------- + +std::vector opInvertHalf(const VarDef &varRes, + const VarDef &varLeft) { + // Full vector — iterate over all 8 scalar swizzle lanes + if (varLeft.swizzle.empty() && varRes.swizzle.empty()) { + std::vector res; + for (char sw : {'x', 'y', 'z', 'w', 'X', 'Y', 'Z', 'W'}) { + std::string s(1, sw); + VarDef resCopy = varRes; + resCopy.swizzle = s; + VarDef leftCopy = varLeft; + leftCopy.swizzle = s; + auto lane = opInvertHalf(resCopy, leftCopy); + res.insert(res.end(), lane.begin(), lane.end()); + } + return res; + } + + auto sitRes = SWIZZLE_MAP.find(varRes.swizzle); + auto sitArg = SWIZZLE_MAP.find(varLeft.swizzle); + std::string sRes = + sitRes != SWIZZLE_MAP.end() ? sitRes->second : ""; + std::string sArg = + sitArg != SWIZZLE_MAP.end() ? sitArg->second : ""; + + if (varRes.type == TypeClass::Vec32 && varLeft.type == TypeClass::Vec16) { + return { + asmOp("vrcp", {fractReg(varRes) + sRes, varLeft.reg + sArg}), + asmOp("vrcph", {intReg(varRes) + sRes, varLeft.reg + sArg})}; + } + return { + asmOp("vrcph", {intReg(varRes) + sRes, intReg(varLeft) + sArg}), + asmOp("vrcpl", + {fractReg(varRes) + sRes, fractReg(varLeft) + sArg}), + asmOp("vrcph", {intReg(varRes) + sRes, + std::string(reg::Reg::VZERO) + sArg})}; +} + +std::vector opInvertSqrtHalf(const VarDef &varRes, + const VarDef &varLeft) { + // Full vector — iterate over all 8 scalar swizzle lanes + if (varLeft.swizzle.empty() && varRes.swizzle.empty()) { + std::vector res; + for (char sw : {'x', 'y', 'z', 'w', 'X', 'Y', 'Z', 'W'}) { + std::string s(1, sw); + VarDef resCopy = varRes; + resCopy.swizzle = s; + VarDef leftCopy = varLeft; + leftCopy.swizzle = s; + auto lane = opInvertSqrtHalf(resCopy, leftCopy); + res.insert(res.end(), lane.begin(), lane.end()); + } + return res; + } + + auto sitRes = SWIZZLE_MAP.find(varRes.swizzle); + auto sitArg = SWIZZLE_MAP.find(varLeft.swizzle); + std::string sRes = + sitRes != SWIZZLE_MAP.end() ? sitRes->second : ""; + std::string sArg = + sitArg != SWIZZLE_MAP.end() ? sitArg->second : ""; + + return { + asmOp("vrsqh", + {intReg(varRes) + sRes, intReg(varLeft) + sArg}), + asmOp("vrsql", + {fractReg(varRes) + sRes, fractReg(varLeft) + sArg}), + asmOp("vrsqh", + {intReg(varRes) + sRes, std::string(reg::Reg::VZERO) + + ".e0"})}; +} + +std::vector opDivVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight) { + state.throwError( + "Vector division is not supported! Use invert_half() or " + "shift-right instead."); + return {}; +} + +// --- Compare ---------------------------------------------------------- + +std::vector opCompareVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight, + const std::string &op, + const ast::TernaryPart *ternary) { + if (!ternary && isTwoRegType(varRes.type)) + state.throwError("Vector comparison can only use vec16!"); + if (varLeft.type != TypeClass::Vec16) + state.throwError("Vector comparison can only use vec16!"); + if (varRight.type != TypeClass::Vec16 && !varRight.reg.empty()) + state.throwError("Vector comparison can only use vec16!"); + if (!varRes.swizzle.empty()) + state.throwError( + "Vector comparison result variable cannot use swizzle!"); + if (!varLeft.swizzle.empty()) + state.throwError( + "Vector comparison left-side cannot use swizzle!"); + + static const std::unordered_map opMap = { + {"<", "vlt"}, {"==", "veq"}, {"!=", "vne"}, {">=", "vge"}, + }; + auto it = opMap.find(op); + if (it == opMap.end()) + state.throwError("Unsupported comparison operator: " + op); + + std::string swizzleRight; + if (!varRight.swizzle.empty()) { + auto sit = SWIZZLE_MAP.find(varRight.swizzle); + if (sit != SWIZZLE_MAP.end()) swizzleRight = sit->second; + } + + std::string dstReg = ternary ? reg::Reg::VTEMP0 : varRes.reg; + std::vector res; + res.push_back(asmOp(it->second, + {dstReg, varLeft.reg, varRight.reg + swizzleRight})); + if (ternary) { + // Emit select (vmrg) for ternary + VarDef vLeft, vRight; + if (ternary->left == "VZERO") { + vLeft.reg = reg::Reg::VZERO; + vLeft.type = TypeClass::Vec16; + } else { + vLeft = state.getRequiredVarCopy(ternary->left, "ternary-left"); + } + if (ternary->rightVal.has_value()) { + auto pIt = POW2_SWIZZLE_VAR.find(ternary->rightVal.value()); + if (pIt != POW2_SWIZZLE_VAR.end()) { + vRight.reg = pIt->second.reg; + vRight.swizzle = pIt->second.swizzle; + vRight.type = TypeClass::Vec16; + } + } else if (!ternary->right.empty()) { + vRight = state.getRequiredVarCopy(ternary->right, "ternary-right"); + vRight.swizzle = ternary->swizzleRight; + } + auto regsDst = getVec32Regs(varRes); + auto regsL = getVec32Regs(vLeft); + auto regsR = getVec32Regs(vRight); + adjustRegsForResType(varRes, vLeft, vRight, regsL, regsR); + auto sit = SWIZZLE_MAP.find(vRight.swizzle); + std::string swSuffix = sit != SWIZZLE_MAP.end() ? sit->second : ""; + if (swSuffix == ".v") swSuffix = ""; + res.push_back(asmOp("vmrg", {regsDst.first, regsL.first, + regsR.first + swSuffix})); + res.push_back(asmOp("vmrg", {regsDst.second, regsL.second, + regsR.second + swSuffix})); + } + return res; +} + +} // namespace rspl::ops diff --git a/cpp/src/operations/vector.h b/cpp/src/operations/vector.h new file mode 100644 index 0000000..d9ac7a3 --- /dev/null +++ b/cpp/src/operations/vector.h @@ -0,0 +1,87 @@ +#pragma once + +#include "../asm.h" +#include "../ast.h" +#include "../state.h" + +#include +#include +#include + +namespace rspl::ops { + +// --- Vec32 helpers (used by builtins too) ----------------------------- + +std::pair getVec32Regs(const VarDef &v); + +// Move/assign (vector->vector, scalar->vector, constant->vector) +std::vector opMoveVec(const VarDef &varRes, + const VarDef &varRight); + +// Load from memory +std::vector opLoadVec(const VarDef &varRes, const VarOrMem &varLoc, + const VarOrMem &varOffset, + const std::string &swizzle, + bool isPackedByte = false, + bool isSigned = true, + bool isUnaligned = false); +std::vector opLoadBytes(const VarDef &varRes, + const VarOrMem &varLoc, + const VarOrMem &varOffset, + const std::string &swizzle, bool isSigned); + +// Store to memory +std::vector opStoreVec(const VarDef &varRes, + const std::vector &varOffsets, + bool isPackedByte = false, + bool isSigned = true, + bool isUnaligned = false); +std::vector opStoreBytes(const VarDef &varRes, + const std::vector &varOffsets, + bool isSigned); + +// Arithmetic +std::vector opAddVec(const VarDef &varRes, const VarDef &varLeft, + VarDef varRight); +std::vector opSubVec(const VarDef &varRes, const VarDef &varLeft, + VarDef varRight); +std::vector opMulVec(const VarDef &varRes, const VarDef &varLeft, + VarDef varRight, + bool clearAccum = true); + +// Shifts +std::vector opShiftLeftVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight); +std::vector opShiftRightVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight, bool logical); + +// Bitwise +std::vector opAndVec(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opOrVec(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opNORVec(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opXORVec(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); +std::vector opBitFlipVec(const VarDef &varRes, + const VarDef &varRight); + +// Special +std::vector opInvertHalf(const VarDef &varRes, + const VarDef &varLeft); +std::vector opInvertSqrtHalf(const VarDef &varRes, + const VarDef &varLeft); +std::vector opDivVec(const VarDef &varRes, const VarDef &varLeft, + const VarDef &varRight); + +// Compare (vector -> stores result in varRes, optionally applies ternary) +std::vector opCompareVec(const VarDef &varRes, + const VarDef &varLeft, + const VarDef &varRight, + const std::string &op, + const ast::TernaryPart *ternary); + +} // namespace rspl::ops diff --git a/cpp/src/optimizer/asm_optimizer.cpp b/cpp/src/optimizer/asm_optimizer.cpp new file mode 100644 index 0000000..d1ba3f6 --- /dev/null +++ b/cpp/src/optimizer/asm_optimizer.cpp @@ -0,0 +1,581 @@ +#include "asm_optimizer.h" +#include "asm.h" +#include "asm_scan_deps.h" +#include "eval_cost.h" +#include "patterns/assertCompare.h" +#include "patterns/branchJump.h" +#include "patterns/commandAlias.h" +#include "patterns/dedupeImm.h" +#include "patterns/dedupeJumps.h" +#include "patterns/dedupeLabels.h" +#include "patterns/mergeSequence.h" +#include "patterns/removeDeadCode.h" +#include "patterns/tailCall.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rspl { + +// --- PRNG (matches JS LCG, thread-local for worker parallelism) ---------- + +static thread_local uint32_t seed_ = 0x41C64E6D; + +void setSeed(uint32_t s) { seed_ = s; } + +static double rand01() { + seed_ = (seed_ * 0x41C64E6D + 0x3039) & 0xFFFFFFFF; + return (seed_ >> 16) / 65536.0; +} + +static int randIndex(int maxExcl) { + seed_ = (seed_ * 0x41C64E6D + 0x3039) & 0xFFFFFFFF; + return (seed_ >> 16) % maxExcl; +} + +// --- Pattern optimization runner -------------------------------------- + +void asmOptimizePattern(AsmFunc &func) { + dedupeLabels(func); + dedupeJumps(func); + branchJump(func); + tailCall(func); + dedupeImmediate(func); + mergeSequence(func); + assertCompare(func); + removeDeadCode(func); + commandAlias(func); +} + +// --- Delay slot filling ----------------------------------------------- + +void fillDelaySlots(AsmFunc &func) { + for (size_t i = 0; i < func.asm_.size(); ++i) { + auto &inst = func.asm_[i]; + if (inst.type != AsmType::OP) continue; + if (inst.opFlags & + (OpFlag::OP_FLAG_IS_IMMOVABLE | OpFlag::OP_FLAG_IS_NOP | + OpFlag::OP_FLAG_IS_BRANCH)) + continue; + + auto reorderRange = asmGetReorderIndices(func.asm_, static_cast(i)); + + int delaySlotIdx = -1; + for (int idx : reorderRange) { + if (idx <= static_cast(i)) continue; + if (func.asm_[idx].opFlags & OpFlag::OP_FLAG_IS_NOP) { + delaySlotIdx = idx; + break; + } + } + + if (delaySlotIdx >= 0) { + func.asm_[delaySlotIdx] = std::move(inst); + func.asm_.erase(func.asm_.begin() + static_cast(i)); + --i; + } + } +} + +// ========================================================================== +// Reorder optimization (stochastic annealing, matches JS algorithm) +// ========================================================================== + +// --- Constants (matching JS) ---------------------------------------------- + +constexpr int POOL_SIZE = 8; +constexpr double PREFER_STALLS_RATE = 0.20; +constexpr double PREFER_PAIR_RATE = 0.80; +constexpr int MAX_STEPS_NO_CHANGE = 5000; +constexpr int SEARCH_VARIANT_SEARCH = 10; +constexpr int SEARCH_BACK_STEPS_FACTOR = 10; +constexpr int SEARCH_FWD_STEPS_FACTOR = 5; +constexpr int REORDER_MIN_OPS = 3; +constexpr int REORDER_MAX_OPS = 15; +constexpr int PROGRESS_LOG_INTERVAL = 500; // meta-iterations between logs + +// --- Helper functions ----------------------------------------------------- + +static AsmFunc cloneFunction(const AsmFunc &func) { + // Only asm_ is needed by reorderRound / evalFunctionCost / asmInitDeps. + // Skip copying name, type, argSize, annotations, etc. + AsmFunc copy; + copy.asm_ = func.asm_; + return copy; +} + +// Forward-declared +static std::pair generateWorseFunction(const AsmFunc &base, + int steps); + +// --- relocateElement (matches JS relocateElement) -------------------------- + +static void relocateElement(std::vector &arr, int from, int to) { + if (from == to) return; + if (arr[to].opFlags & OpFlag::OP_FLAG_IS_BRANCH) return; + bool targetIsNOP = arr[to].opFlags & OpFlag::OP_FLAG_IS_NOP; + bool sourceInDelaySlot = + (from >= 1) && (arr[from - 1].opFlags & OpFlag::OP_FLAG_IS_BRANCH); + + if (sourceInDelaySlot) { + if (targetIsNOP) { + // Replace NOP with delay-slot instruction (keep delay slot filled) + arr[to] = arr[from]; + } else { + AsmInst inst = std::move(arr[from]); + arr[from] = asmNOP(); + asmInitDep(arr[from]); + arr.insert(arr.begin() + to, std::move(inst)); + } + } else { + if (targetIsNOP) { + arr[to] = std::move(arr[from]); + arr.erase(arr.begin() + from); + } else { + AsmInst inst = std::move(arr[from]); + arr.erase(arr.begin() + from); + if (to > from) to--; + arr.insert(arr.begin() + to, std::move(inst)); + } + } +} + +// --- optimizeStep (matches JS optimizeStep) -------------------------------- + +static int optimizeStep(AsmFunc &func) { + auto sz = static_cast(func.asm_.size()); + if (sz < 2) return 0; + + int i = 0; + std::vector reorderIndices; + for (int r = 0; r < 50; ++r) { + i = randIndex(sz); + reorderIndices = asmGetReorderIndices(func.asm_, i); + if ((int)reorderIndices.size() > 1) break; + } + if ((int)reorderIndices.size() <= 1) return 0; + + int targetIdx = i; + bool foundIndex = false; + + // Prefer pairing opposite-type (vector<->scalar) unpaired instructions + if (rand01() < PREFER_PAIR_RATE) { + for (int j : reorderIndices) { + if ((func.asm_[j].opFlags & OpFlag::OP_FLAG_IS_VECTOR) != + (func.asm_[i].opFlags & OpFlag::OP_FLAG_IS_VECTOR)) { + if (!func.asm_[j].debug.paired) { + targetIdx = j; + foundIndex = true; + } + } + } + if (!foundIndex) return 0; + } + + // Prefer filling high-stall positions + if (!foundIndex && rand01() < PREFER_STALLS_RATE) { + int maxStalls = 0; + for (int j : reorderIndices) { + int stalls = func.asm_[j].debug.stall; + if (stalls > maxStalls) { + maxStalls = stalls; + targetIdx = j; + foundIndex = true; + } + } + } + + if (!foundIndex) { + while (targetIdx == i) { + targetIdx = reorderIndices[randIndex((int)reorderIndices.size())]; + } + } + + func.asm_[i].debug.reorderCount++; + relocateElement(func.asm_, i, targetIdx); + return 1; +} + +// --- reorderRound (matches JS reorderRound) -------------------------------- + +struct RoundResult { + int cost; + std::vector asm_; +}; + +// --- Phase-level timing (printed every N iterations) -------------------- + +struct PhaseTiming { + double cloneMs = 0; + double reorderMs = 0; + double depsMs = 0; + double evalMs = 0; + double dispatchMs = 0; + double resultsMs = 0; + int samples = 0; + void reset() { *this = {}; } +}; +PhaseTiming g_phaseTiming; + +// Mutable variant: takes ownership, avoids double-clone from WorkerPool. +static RoundResult reorderRoundImpl(AsmFunc func) { + int opCount = randIndex(REORDER_MAX_OPS - REORDER_MIN_OPS) + REORDER_MIN_OPS; + for (int o = 0; o < opCount; ++o) + optimizeStep(func); + // optimizeStep already keeps dep data current via asmInitDep calls inside + // relocateElement. Bulk rescan is redundant — deps are position-independent. + int cost = evalFunctionCost(func); + return {cost, std::move(func.asm_)}; +} + +static RoundResult reorderRound(const AsmFunc &baseFunc) { + return reorderRoundImpl(cloneFunction(baseFunc)); +} + +// --- generateWorseFunction (matches JS generateWorseFunction) --------------- + +static std::pair generateWorseFunction(const AsmFunc &base, + int steps) { + int maxCost = 0; + AsmFunc newWorst = cloneFunction(base); + for (int i = 0; i < steps; ++i) { + AsmFunc f = cloneFunction(base); + reorderRound(f); + reorderRound(f); + asmInitDeps(f); + int cost = evalFunctionCost(f); + if (cost > maxCost) { + newWorst = std::move(f); + maxCost = cost; + } + } + return {std::move(newWorst), maxCost}; +} + +// --- Helper: format time string from ms ----------------------------------- + +static std::string formatTimeMs(int ms) { + int h = ms / 3600000; + int m = (ms % 3600000) / 60000; + int s = (ms % 60000) / 1000; + std::ostringstream ss; + ss << std::setfill('0') << std::setw(2) << h << ":"; + ss << std::setfill('0') << std::setw(2) << m << ":"; + ss << std::setfill('0') << std::setw(2) << s; + return ss.str(); +} + +// --- Parallel variant execution ------------------------------------------- + +// Each worker runs a full variant (clone → reorderRound) independently. +// The caller dispatches a batch of N variants; all threads (including caller) +// pull from a shared index. Results go into a freshly-allocated vector so +// there's no reuse of moved-from state between calls. + +class WorkerPool { +public: + explicit WorkerPool(int numWorkers) { + for (int i = 0; i < numWorkers; ++i) + threads_.emplace_back(&WorkerPool::run, this, i); + } + + ~WorkerPool() { + { + std::lock_guard lk(mtx_); + stop_ = true; + } + cv_.notify_all(); + for (auto &t : threads_) + if (t.joinable()) t.join(); + } + + // Run `count` variants of `base` in parallel. Returns results. + std::vector runParallel(const AsmFunc &base, int count) { + results_.resize(count); + nextIdx_.store(0, std::memory_order_release); + doneCount_.store(0, std::memory_order_release); + + { + std::lock_guard lk(mtx_); + base_ = &base; + batchCount_ = count; + batchActive_ = true; + } + cv_.notify_all(); + + // Give workers a head start before caller joins + std::this_thread::yield(); + + // Wait until all tasks are completed (caller helps if any remain) + while (doneCount_.load(std::memory_order_acquire) < (size_t)count) { + workBatch(count); + } + + // Barrier: wait for all workers to exit workBatch, then cleanup + threadsDone_.store(1, std::memory_order_release); + { + std::unique_lock lk(mtx_); + cv_.wait(lk, [&] { + return (size_t)threadsDone_.load(std::memory_order_acquire) > + threads_.size(); + }); + batchActive_ = false; + base_ = nullptr; + } + + std::vector out; + results_.swap(out); + return out; + } + +private: + std::vector threads_; + std::vector results_; + std::atomic nextIdx_{0}; + std::atomic doneCount_{0}; + std::atomic threadsDone_{0}; + const AsmFunc *base_ = nullptr; + int batchCount_ = 0; + bool batchActive_ = false; + bool stop_ = false; + std::mutex mtx_; + std::condition_variable cv_; + + void workBatch(int count) { + while (true) { + size_t idx = nextIdx_.fetch_add(1, std::memory_order_acq_rel); + if ((int)idx >= count) break; + AsmFunc variant = cloneFunction(*base_); + results_[idx] = reorderRoundImpl(std::move(variant)); + doneCount_.fetch_add(1, std::memory_order_release); + } + } + + void run(int id) { + std::random_device rd; + setSeed(rd() ^ (static_cast(id) * 0x9E3779B9)); + + while (true) { + int count; + { + std::unique_lock lk(mtx_); + cv_.wait(lk, [&] { return stop_ || batchActive_; }); + if (stop_) return; + count = batchCount_; + } + workBatch(count); + // Signal worker finished this pass, wake caller for barrier + threadsDone_.fetch_add(1, std::memory_order_release); + cv_.notify_one(); + } + } +}; + +// --- asmOptimize (matches JS asmOptimize) ---------------------------------- + +// --- Cumulative perf counters ------------------------------------------- + +static int64_t g_totalIterations = 0; +static double g_totalWallMs = 0.0; + +void printCumulativeStats() { + if (g_totalIterations == 0) return; + double ips = g_totalWallMs > 0.0 + ? g_totalIterations / (g_totalWallMs / 1000.0) + : 0.0; + std::cerr << "\n=== Reorder Summary =======================" << std::endl; + std::cerr << " Total iterations: " << g_totalIterations << std::endl; + std::cerr << " Total wall time: " << std::fixed << std::setprecision(1) + << g_totalWallMs << " ms" << std::endl; + std::cerr << " IPS: " << std::setprecision(0) << ips << std::endl; +} + +void asmOptimize(AsmFunc &func, int maxTimeMs, int optWorkers) { + const std::string &funcName = + func.name.empty() ? "(???)" : func.name; + + asmInitDeps(func); + int costBest = evalFunctionCost(func); + func.cyclesBefore = costBest; + int costInit = costBest; + + std::cerr << "Starting optimization of '" << funcName + << "' with max. time: " << formatTimeMs(maxTimeMs) << std::endl; + + // Initialize random seed from system entropy (JS uses Math.random) + std::random_device rd; + setSeed(rd()); + + // Create worker pool (one thread per hardware core, minus calling thread) + int numWorkers = optWorkers > 0 + ? optWorkers + : std::max(1, static_cast(std::thread::hardware_concurrency()) - 1); + WorkerPool pool(numWorkers); + std::cerr << "[" << funcName << "] Worker pool: " << numWorkers + << " threads" << std::endl; + + AsmFunc lastRandPick = cloneFunction(func); + + auto startTime = std::chrono::steady_clock::now(); + auto deadline = startTime + std::chrono::milliseconds(maxTimeMs); + + int i = 0; + int metaIter = 0; + int stepsSinceLastOpt = 0; + int consecutiveSame = 0; + double totalTime = 0.0; + auto iterStart = startTime; + + while (totalTime < maxTimeMs) { + auto now = std::chrono::steady_clock::now(); + + // Progress logging + if (metaIter < 5 || (metaIter % PROGRESS_LOG_INTERVAL) == 0) { + auto dur = std::chrono::duration(now - iterStart) + .count(); + totalTime += dur; + double wallSec = + std::chrono::duration(now - startTime).count(); + double ips = wallSec > 0.0 ? i / wallSec : 0.0; + double left = maxTimeMs - totalTime; + std::cerr << "[" << funcName << "] Step: " << i + << ", Left: " << std::fixed << std::setprecision(1) << left + << "ms | Cost: " << costBest + << " | ips: " << std::setprecision(0) << ips; + + // Phase breakdown (every PROGRESS_LOG_INTERVAL meta-iterations) + if (metaIter > 0 && (metaIter % PROGRESS_LOG_INTERVAL) == 0 && + g_phaseTiming.samples > 0) { + double total = g_phaseTiming.cloneMs + g_phaseTiming.reorderMs + + g_phaseTiming.depsMs + g_phaseTiming.evalMs + + g_phaseTiming.dispatchMs + g_phaseTiming.resultsMs; + auto pct = [&](double v) { return (int)(v / total * 100); }; + std::cerr << "\n [profile] clone:" << pct(g_phaseTiming.cloneMs) + << "% reorder:" << pct(g_phaseTiming.reorderMs) + << "% deps:" << pct(g_phaseTiming.depsMs) + << "% eval:" << pct(g_phaseTiming.evalMs) + << "% dispatch:" << pct(g_phaseTiming.dispatchMs) + << "% results:" << pct(g_phaseTiming.resultsMs) + << " (samples:" << g_phaseTiming.samples << ")"; + } + std::cerr << std::endl; + iterStart = now; + } + + // Check timeout + if (now > deadline) { + double funcElapsedMs = + std::chrono::duration(now - startTime).count(); + double funcIps = + funcElapsedMs > 0.0 ? i / (funcElapsedMs / 1000.0) : 0.0; + g_totalIterations += i; + g_totalWallMs += funcElapsedMs; + std::cerr << "[" << funcName << "] Timeout after " << i + << " iterations (" << std::fixed << std::setprecision(0) + << funcIps << " ips)." << std::endl; + break; + } + + AsmFunc funcCopy = cloneFunction(func); + std::vector results; + int effectivePool = POOL_SIZE * numWorkers; + + if (stepsSinceLastOpt > MAX_STEPS_NO_CHANGE) { + ++consecutiveSame; + int stepsBack = consecutiveSame * SEARCH_BACK_STEPS_FACTOR; + int stepsFwd = consecutiveSame * SEARCH_FWD_STEPS_FACTOR; + std::cerr << "[" << funcName << "] " << stepsSinceLastOpt + << " steps since last improvement, generate new versions (" + << stepsBack << " steps backward)" << std::endl; + + // Escape local minimum: generate worse variants (sequential, each + // uses many reorderRound calls internally), then finalize in parallel. + for (int s = 0; s < SEARCH_VARIANT_SEARCH; ++s) { + auto [worseCopy, maxCost] = + generateWorseFunction(funcCopy, stepsBack); + for (int t = 0; t < stepsFwd; ++t) { + AsmFunc worseCopyTry = cloneFunction(worseCopy); + reorderRound(worseCopyTry); + asmInitDeps(worseCopyTry); + int cost = evalFunctionCost(worseCopyTry); + if (cost < maxCost) { + worseCopy.asm_ = std::move(worseCopyTry.asm_); + maxCost = cost; + } + } + // Finalize escape variant via reorderRound + AsmFunc variant = cloneFunction(worseCopy); + results.push_back(reorderRound(variant)); + } + // Remaining pool slots: if any left, run in parallel. + int remaining = effectivePool - SEARCH_VARIANT_SEARCH; + if (remaining > 0) { + auto tD0 = std::chrono::steady_clock::now(); + auto extraResults = pool.runParallel(func, remaining); + g_phaseTiming.dispatchMs += + std::chrono::duration( + std::chrono::steady_clock::now() - tD0).count(); + results.insert(results.end(), + std::make_move_iterator(extraResults.begin()), + std::make_move_iterator(extraResults.end())); + } + stepsSinceLastOpt = 0; + } else { + auto tD0 = std::chrono::steady_clock::now(); + results = pool.runParallel(func, effectivePool); + g_phaseTiming.dispatchMs += + std::chrono::duration( + std::chrono::steady_clock::now() - tD0).count(); + } + + auto tR0 = std::chrono::steady_clock::now(); + for (int s = 0; s < (int)results.size(); ++s) { + const auto &[cost, asm_] = results[s]; + // Safety: a cost of 0 means the variant is broken (no instructions or + // dependency corruption). Reject it to prevent poisoning func.asm_. + if (cost == 0) continue; + bool isBetter = cost < costBest; + bool isSame = cost == costBest; + bool canUseTheSame = s < ((int)results.size() / 4); + + if (isBetter || (canUseTheSame && isSame)) { + costBest = cost; + func.asm_ = asm_; + func.cyclesAfter = cost; + + if (isBetter) { + std::cerr << "[" << funcName << "] \033[32m**** New Best for '" + << funcName << "': " << costInit << " -> " << cost + << " ****\033[0m" << std::endl; + stepsSinceLastOpt = 0; + consecutiveSame = 0; + } + } + } + + g_phaseTiming.resultsMs += + std::chrono::duration( + std::chrono::steady_clock::now() - tR0).count(); + + if (i % 3 == 0) lastRandPick = funcCopy; + i += effectivePool; + ++metaIter; + ++stepsSinceLastOpt; + } + + func.cyclesBefore = costInit; + func.cyclesAfter = costBest; +} + +} // namespace rspl \ No newline at end of file diff --git a/cpp/src/optimizer/asm_optimizer.h b/cpp/src/optimizer/asm_optimizer.h new file mode 100644 index 0000000..d208473 --- /dev/null +++ b/cpp/src/optimizer/asm_optimizer.h @@ -0,0 +1,25 @@ +#pragma once + +#include "../asm.h" + +namespace rspl { + +/// Run pattern-based optimizations (dedupe labels, dedupe jumps, etc.) +void asmOptimizePattern(AsmFunc &func); + +/// Fill NOP delay slots by moving independent instructions forward. +/// Must be called after asmScanDeps. +void fillDelaySlots(AsmFunc &func); + +/// Set the PRNG seed for reproducible reorder results (used by tests). +void setSeed(uint32_t s); + +/// Run reorder optimization (stochastic annealing) on a single function. +/// optWorkers: 0 = auto-detect, otherwise that many threads. +void asmOptimize(AsmFunc &func, int maxTimeMs = 30'000, int optWorkers = 0); + +/// Print cumulative reorder stats (total iterations, average IPS) +/// across all asmOptimize calls since program start. +void printCumulativeStats(); + +} // namespace rspl \ No newline at end of file diff --git a/cpp/src/optimizer/asm_scan_deps.cpp b/cpp/src/optimizer/asm_scan_deps.cpp new file mode 100644 index 0000000..1975a6b --- /dev/null +++ b/cpp/src/optimizer/asm_scan_deps.cpp @@ -0,0 +1,604 @@ +#include "asm_scan_deps.h" +#include "../registers.h" +#include "../state.h" +#include "../swizzle.h" +#include "../types.h" + +#include +#include +#include +#include + +namespace rspl { + +// --- Register index lookup (295 entries, O(1) without hashing) --------- + +// Scalar register order: $zero..$ra (standard MIPS ABI, $at excluded) +// Must match the order in the REG_INDEX_MAP (256..287). +static int scalarRegIdx(const char *s) { + if (s[0] == '$') ++s; else return -1; + if (s[0] == 'z' && s[1] == 'e') return 0; // $zero + if (s[0] == 'a' && s[1] == 't') return 1; // $at + if (s[0] == 'v' && s[1] >= '0' && s[1] <= '1') return 2 + (s[1]-'0'); // $v0-1 + if (s[0] == 'a' && s[1] >= '0' && s[1] <= '3') return 4 + (s[1]-'0'); // $a0-3 + if (s[0] == 't' && s[1] >= '0' && s[1] <= '7') return 8 + (s[1]-'0'); // $t0-7 + if (s[0] == 's' && s[1] >= '0' && s[1] <= '7') return 16 + (s[1]-'0'); // $s0-7 + if (s[0] == 't' && s[1] == '8') return 24; // $t8 + if (s[0] == 't' && s[1] == '9') return 25; // $t9 + if (s[0] == 'k' && s[1] == '0') return 26; // $k0 + if (s[0] == 'k' && s[1] == '1') return 27; // $k1 + if (s[0] == 'g' && s[1] == 'p') return 28; // $gp + if (s[0] == 's' && s[1] == 'p') return 29; // $sp + if (s[0] == 'f' && s[1] == 'p') return 30; // $fp + if (s[0] == 'r' && s[1] == 'a') return 31; // $ra + return -1; +} + +// Fast register → index lookup. Returns -1 if not found. +int getRegIndex(const std::string &name) { + if (name.size() < 3 || name[0] != '$') return -1; + + // Vector register: $vNN or $vNN_L (must have two digits after 'v', + // distinguishing from scalar $v0/$v1 which have only one digit). + if (name[1] == 'v' && name.size() >= 4 && name[2] >= '0' && name[2] <= '9' && + name[3] >= '0' && name[3] <= '9') { + int vnum = (name[2] - '0') * 10 + (name[3] - '0'); + if (vnum >= 32) return -1; + int base = vnum * 8; + auto uscore = name.find('_', 4); + if (uscore == std::string::npos) return base; + int lane = name[uscore + 1] - '0'; + if (lane >= 0 && lane < 8) return base + lane; + return base; + } + + // Scalar: $zero..$ra + int si = scalarRegIdx(name.c_str()); + if (si >= 0) return 256 + si; + + // Special registers + if (name == "$vco") return 288; + if (name == "$vcc") return 289; + if (name == "$acc") return 290; + if (name == "$divOut") return 291; + if (name == "$divIn") return 292; + if (name == "$divDP") return 293; + if (name == "$vce") return 294; + + return -1; +} + +// --- Stall index lookup (64 entries) ----------------------------------- + +int getRegStallIndex(const std::string &name) { + return getRegStallIndex(name.c_str(), name.size()); +} + +int getRegStallIndex(const char *name, size_t len) { + if (len < 3 || name[0] != '$') return -1; + + // Vector register: $vNN (two digits after 'v') + if (name[1] == 'v' && len >= 4 && name[2] >= '0' && name[2] <= '9' && + name[3] >= '0' && name[3] <= '9') { + int vnum = (name[2] - '0') * 10 + (name[3] - '0'); + if (vnum >= 32) return -1; + return 32 + vnum; + } + + // Scalar + int si = scalarRegIdx(name); + if (si >= 0) return si; + return -1; +} + +// --- Hidden registers ------------------------------------------------- + +const std::unordered_map> + HIDDEN_REGS_READ = []() { + std::unordered_map> m; + m[getOpcode("vlt")] = {"$vco"}; + m[getOpcode("veq")] = {"$vco"}; + m[getOpcode("vne")] = {"$vco"}; + m[getOpcode("vge")] = {"$vco"}; + m[getOpcode("vmrg")] = {"$vcc"}; + m[getOpcode("vcl")] = {"$vco", "$vce"}; + m[getOpcode("vmacf")] = {"$acc"}; + m[getOpcode("vmacu")] = {"$acc"}; + m[getOpcode("vmadn")] = {"$acc"}; + m[getOpcode("vmadl")] = {"$acc"}; + m[getOpcode("vmadm")] = {"$acc"}; + m[getOpcode("vmadh")] = {"$acc"}; + m[getOpcode("vrndp")] = {"$acc"}; + m[getOpcode("vrndn")] = {"$acc"}; + m[getOpcode("vmacq")] = {"$acc"}; + m[getOpcode("vsar")] = {"$acc"}; + m[getOpcode("vrcph")] = {"$divOut"}; + m[getOpcode("vrsqh")] = {"$divOut"}; + m[getOpcode("vrcpl")] = {"$divIn", "$divDP"}; + m[getOpcode("vrsql")] = {"$divIn", "$divDP"}; + m[getOpcode("vadd")] = {"$vco"}; + m[getOpcode("vsub")] = {"$vco"}; + return m; + }(); + +const std::unordered_map> + HIDDEN_REGS_WRITE = []() { + std::unordered_map> m; + m[getOpcode("vlt")] = {"$vcc", "$vco", "$acc"}; + m[getOpcode("veq")] = {"$vcc", "$vco", "$acc"}; + m[getOpcode("vne")] = {"$vcc", "$vco", "$acc"}; + m[getOpcode("vge")] = {"$vcc", "$vco", "$acc"}; + m[getOpcode("vch")] = {"$vcc", "$vco", "$acc", "$vce"}; + m[getOpcode("vcr")] = {"$vcc", "$vco", "$acc", "$vce"}; + m[getOpcode("vcl")] = {"$vcc", "$vco", "$acc", "$vce"}; + m[getOpcode("vmrg")] = {"$vco", "$acc"}; + m[getOpcode("vmov")] = {"$acc"}; + m[getOpcode("vrcp")] = {"$acc", "$divOut", "$divDP"}; + m[getOpcode("vrcph")] = {"$acc", "$divIn", "$divDP"}; + m[getOpcode("vrcpl")] = {"$acc", "$divOut", "$divDP"}; + m[getOpcode("vrsq")] = {"$acc", "$divOut", "$divDP"}; + m[getOpcode("vrsqh")] = {"$acc", "$divIn", "$divDP"}; + m[getOpcode("vrsql")] = {"$acc", "$divOut", "$divDP"}; + m[getOpcode("vadd")] = {"$vco", "$acc"}; + m[getOpcode("vsub")] = {"$vco", "$acc"}; + m[getOpcode("vaddc")] = {"$vco", "$acc"}; + m[getOpcode("vsubc")] = {"$vco", "$acc"}; + m[getOpcode("vabs")] = {"$acc"}; + m[getOpcode("vand")] = {"$acc"}; + m[getOpcode("vnand")] = {"$acc"}; + m[getOpcode("vor")] = {"$acc"}; + m[getOpcode("vnor")] = {"$acc"}; + m[getOpcode("vxor")] = {"$acc"}; + m[getOpcode("vnxor")] = {"$acc"}; + m[getOpcode("vmulf")] = {"$acc"}; + m[getOpcode("vmulu")] = {"$acc"}; + m[getOpcode("vmacf")] = {"$acc"}; + m[getOpcode("vmacu")] = {"$acc"}; + m[getOpcode("vmudn")] = {"$acc"}; + m[getOpcode("vmadn")] = {"$acc"}; + m[getOpcode("vmudl")] = {"$acc"}; + m[getOpcode("vmadl")] = {"$acc"}; + m[getOpcode("vmudm")] = {"$acc"}; + m[getOpcode("vmadm")] = {"$acc"}; + m[getOpcode("vmudh")] = {"$acc"}; + m[getOpcode("vmadh")] = {"$acc"}; + m[getOpcode("vrndp")] = {"$acc"}; + m[getOpcode("vrndn")] = {"$acc"}; + m[getOpcode("vmulq")] = {"$acc"}; + m[getOpcode("vmacq")] = {"$acc"}; + return m; + }(); + +static const std::unordered_set STALL_IGNORE_REGS = { + "$vcc", "$vco", "$acc", "$vce", "$divOut", "$divIn", "$divDP"}; + +static const std::unordered_set READ_ONLY_OPS = []() { + std::unordered_set s; + for (auto *op : {"beq","bne","bgezal","bltzal","bgez","bltz", + "blez","bgtz","j","jr","jal", + "sw","sh","sb","sbv","ssv","slv","sdv", + "sqv","spv","suv","shv","sfv","stv","swv","srv", + "mtc0"}) + s.insert(getOpcode(op)); + return s; +}(); + +// --- LTV/STV register groups ------------------------------------------ + +static const std::unordered_map> + LTV_REG_MAP = { + {"$v00", + {"$v00", "$v01", "$v02", "$v03", "$v04", "$v05", "$v06", "$v07"}}, + {"$v08", + {"$v08", "$v09", "$v10", "$v11", "$v12", "$v13", "$v14", "$v15"}}, + {"$v16", + {"$v16", "$v17", "$v18", "$v19", "$v20", "$v21", "$v22", "$v23"}}, + {"$v24", + {"$v24", "$v25", "$v26", "$v27", "$v28", "$v29", "$v30", "$v31"}}, +}; + +// --- Mask utilities --------------------------------------------------- + +static bool maskIsZero(const RegMask &m) { + return m[0] == 0 && m[1] == 0 && m[2] == 0 && m[3] == 0 && m[4] == 0; +} + +static bool maskAnd(const RegMask &a, const RegMask &b) { + for (int i = 0; i < 5; ++i) + if (a[i] & b[i]) return true; + return false; +} + +static void maskOr(RegMask &dst, const RegMask &src) { + for (int i = 0; i < 5; ++i) dst[i] |= src[i]; +} + +static RegMask maskAll() { + // All bits set for all 295 registers + RegMask m = {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, + 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, + 0x7FFFFFFFULL}; // bits 0–294 + return m; +} + +// --- Expand vector registers to lanes --------------------------------- + +// --- Precomputed register expansion cache ------------------------------ + +namespace { +struct ExpandCache { + std::unordered_map> map; + ExpandCache() { + static const std::unordered_map> laneMap = { + {"v",{0,1,2,3,4,5,6,7}},{".q0",{0,2,4,6}},{".q1",{1,3,5,7}}, + {".h0",{0,4}},{".h1",{1,5}},{".h2",{2,6}},{".h3",{3,7}}, + {".e0",{0}},{".e1",{1}},{".e2",{2}},{".e3",{3}}, + {".e4",{4}},{".e5",{5}},{".e6",{6}},{".e7",{7}}, + }; + const char *vecs[] = {"$v00","$v01","$v02","$v03","$v04","$v05","$v06", + "$v07","$v08","$v09","$v10","$v11","$v12","$v13","$v14","$v15", + "$v16","$v17","$v18","$v19","$v20","$v21","$v22","$v23", + "$v24","$v25","$v26","$v27","$v28","$v29","$v30","$v31"}; + for (auto *v : vecs) { + std::string reg(v); + auto &full = map[reg]; + for (int l = 0; l < 8; ++l) full.push_back(reg + "_" + std::to_string(l)); + for (auto &[ln, lanes] : laneMap) { + std::string key = reg + (ln[0] == '.' ? ln : ""); + if (ln == "v") { map[key] = full; continue; } + for (int l : lanes) map[key].push_back(reg + "_" + std::to_string(l)); + } + } + const char *scalars[] = {"$zero","$at","$v0","$v1","$a0","$a1","$a2","$a3", + "$t0","$t1","$t2","$t3","$t4","$t5","$t6","$t7", + "$s0","$s1","$s2","$s3","$s4","$s5","$s6","$s7", + "$t8","$t9","$k0","$k1","$gp","$sp","$fp","$ra", + "$vcc","$vco","$vce","$acc","$divOut","$divIn","$divDP"}; + for (auto *s : scalars) map[s] = {s}; + } +}; +} // namespace + +const std::vector &expandRegister(const std::string ®Name) { + static const ExpandCache cache; + auto it = cache.map.find(regName); + if (it != cache.map.end()) return it->second; + + // Fallback: swizzle patterns not in cache (e.g. $v01.xyzwXYZW → .v) + auto dotPos = regName.find('.'); + if (dotPos != std::string::npos) { + std::string reg = regName.substr(0, dotPos); + if (reg::isVecReg(reg)) { + auto sit = SWIZZLE_MAP.find(regName.substr(dotPos + 1)); + if (sit != SWIZZLE_MAP.end()) { + auto rit = cache.map.find(reg + sit->second); + if (rit != cache.map.end()) return rit->second; + } + } + } + static thread_local std::vector t_fallback; + t_fallback.clear(); + t_fallback.push_back(regName); + return t_fallback; +} + +// --- Source/target register extraction -------------------------------- + +static std::string extractRegFromArg(const std::string &arg) { + auto brIdx = arg.rfind('('); + if (brIdx != std::string::npos) { + return arg.substr(brIdx + 1, arg.size() - brIdx - 2); + } + if (!arg.empty() && arg[0] == '$') return arg; + return ""; +} + +std::vector getSourceRegs(const AsmInst &inst) { + if (inst.op == Op::JR() || inst.op == Op::MTC2() || + inst.op == Op::MTC0() || inst.op == Op::CTC2()) { + return {inst.args[0]}; + } + if ((inst.opFlags & OpFlag::OP_FLAG_IS_BRANCH) && + getOpcodeName(inst.op)[0] == 'b') { + if (inst.args.empty()) return {}; + return std::vector(inst.args.begin(), + inst.args.end() - 1); + } + if (inst.opFlags & OpFlag::OP_FLAG_IS_STORE) { + if (inst.op == Op::STV()) { + const std::string &mainReg = + inst.args.empty() ? "$v00" : inst.args[0]; + auto it = LTV_REG_MAP.find(mainReg); + if (it == LTV_REG_MAP.end()) { + state.throwError( + "Invalid base register " + mainReg + " for stv!"); + } + int row = inst.args.size() > 1 ? std::stoi(inst.args[1]) / 2 : 0; + std::vector regs; + for (int i = 0; i < 8; ++i) { + regs.push_back(it->second[i] + ".e" + + std::to_string((8 + i - row) % 8)); + } + regs.push_back(inst.args.back()); + return regs; + } + return inst.args; + } + if (inst.op == Op::J() || inst.op == Op::JAL()) { + return {inst.args[0]}; + } + std::vector res(inst.args.begin() + 1, inst.args.end()); + auto hit = HIDDEN_REGS_READ.find(inst.op); + if (hit != HIDDEN_REGS_READ.end()) { + res.insert(res.end(), hit->second.begin(), hit->second.end()); + } + return res; +} + +std::vector getTargetRegs(const AsmInst &inst) { + if (READ_ONLY_OPS.count(inst.op)) return {}; + + if ((inst.opFlags & OpFlag::OP_FLAG_IS_LOAD) && + inst.op == Op::LTV()) { + const std::string &mainReg = + inst.args.empty() ? "$v00" : inst.args[0]; + auto it = LTV_REG_MAP.find(mainReg); + if (it == LTV_REG_MAP.end()) { + state.throwError("Invalid base register " + mainReg + " for ltv!"); + } + int row = inst.args.size() > 1 ? std::stoi(inst.args[1]) / 2 : 0; + std::vector regs; + for (int i = 0; i < 8; ++i) { + regs.push_back(it->second[i] + ".e" + + std::to_string((8 + i - row) % 8)); + } + return regs; + } + + const std::string &targetReg = + (inst.op == Op::MTC2() || inst.op == Op::CTC2()) ? inst.args[1] + : inst.args[0]; + std::vector res = {targetReg}; + auto hit = HIDDEN_REGS_WRITE.find(inst.op); + if (hit != HIDDEN_REGS_WRITE.end()) { + res.insert(res.end(), hit->second.begin(), hit->second.end()); + } + return res; +} + +// --- Dependency initialization ---------------------------------------- + +void asmInitDep(AsmInst &inst) { + // Clear all dependency data + inst.depsSourceIdx.clear(); + inst.depsTargetIdx.clear(); + inst.depsStallSourceIdx.clear(); + inst.depsStallTargetIdx.clear(); + inst.depsSourceMask = {}; + inst.depsTargetMask = {}; + inst.depsStallSourceMask0 = 0; + inst.depsStallSourceMask1 = 0; + inst.depsStallTargetMask0 = 0; + inst.depsStallTargetMask1 = 0; + inst.barrierMask = 0; + + if (inst.type != AsmType::OP || (inst.opFlags & OpFlag::OP_FLAG_IS_NOP)) + return; + + // Scratch dedup: up to 16 unique raw regs before expansion, ≤ 24 after. + // Using a stack-allocated bool array indexed by register number instead + // of unordered_set avoids all intermediate heap allocations. + bool seenBase[64] = {}; // for raw base reg dedup (max 64 stall indices) + + // --- Source registers: expand → mask + idx, track bases for stalls ---- + + auto rawSrc = getSourceRegs(inst); + for (auto &r : rawSrc) { + std::string ext = extractRegFromArg(r); + if (ext.empty() || ext[0] != '$') continue; + + // Track base register for stall info (deduplicated) + auto dotPos = ext.find('.'); + int baseLen = (dotPos != std::string::npos) ? (int)dotPos : (int)ext.size(); + // Use a cheap inline comparison for STALL_IGNORE_REGS + bool isStallIgnored = + (baseLen == 4 && ext[1] == 'v' && ext[2] == 'c' && ext[3] == 'c') || + (baseLen == 4 && ext[1] == 'v' && ext[2] == 'c' && ext[3] == 'o') || + (baseLen == 4 && ext[1] == 'v' && ext[2] == 'c' && ext[3] == 'e') || + (baseLen == 4 && ext[1] == 'a' && ext[2] == 'c' && ext[3] == 'c') || + (baseLen == 6 && ext == "$divOut") || + (baseLen == 5 && ext == "$divIn") || + (baseLen == 5 && ext == "$divDP"); + if (!isStallIgnored) { + int stallIdx = getRegStallIndex(ext.c_str(), baseLen); + if (stallIdx >= 0 && !seenBase[stallIdx]) { + seenBase[stallIdx] = true; + inst.depsStallSourceIdx.push_back(stallIdx); + } + } + + // Expand raw reg to lanes, set mask + index for each + const auto &expanded = expandRegister(ext); + for (const auto &e : expanded) { + int idx = getRegIndex(e); + if (idx >= 0) { + inst.depsSourceMask[idx / 64] |= (1ULL << (idx % 64)); + inst.depsSourceIdx.push_back(idx); + } + } + } + + // --- Target registers: expand → mask + idx, track bases for stalls --- + + // Reuse seenBase for target stall dedup — clear first + for (int &s : inst.depsStallSourceIdx) seenBase[s] = false; + + auto rawTgt = getTargetRegs(inst); + for (auto &r : rawTgt) { + auto dotPos = r.find('.'); + int baseLen = (dotPos != std::string::npos) ? (int)dotPos : (int)r.size(); + bool isStallIgnored = + (baseLen == 4 && r[1] == 'v' && r[2] == 'c' && r[3] == 'c') || + (baseLen == 4 && r[1] == 'v' && r[2] == 'c' && r[3] == 'o') || + (baseLen == 4 && r[1] == 'v' && r[2] == 'c' && r[3] == 'e') || + (baseLen == 4 && r[1] == 'a' && r[2] == 'c' && r[3] == 'c') || + (baseLen == 6 && r == "$divOut") || + (baseLen == 5 && r == "$divIn") || + (baseLen == 5 && r == "$divDP"); + if (!isStallIgnored) { + int stallIdx = getRegStallIndex(r.c_str(), baseLen); + if (stallIdx >= 0 && !seenBase[stallIdx]) { + seenBase[stallIdx] = true; + inst.depsStallTargetIdx.push_back(stallIdx); + } + } + + const auto &expanded = expandRegister(r); + for (const auto &e : expanded) { + int idx = getRegIndex(e); + if (idx >= 0) { + inst.depsTargetMask[idx / 64] |= (1ULL << (idx % 64)); + inst.depsTargetIdx.push_back(idx); + } + } + } + + // --- Stall masks from the idx vectors (already built inline above) ------ + + uint64_t srcStallMask = 0, tgtStallMask = 0; + for (int idx : inst.depsStallSourceIdx) srcStallMask |= (1ULL << idx); + for (int idx : inst.depsStallTargetIdx) tgtStallMask |= (1ULL << idx); + inst.depsStallSourceMask0 = static_cast(srcStallMask); + inst.depsStallSourceMask1 = static_cast(srcStallMask >> 32); + inst.depsStallTargetMask0 = static_cast(tgtStallMask); + inst.depsStallTargetMask1 = static_cast(tgtStallMask >> 32); + + // Barrier mask from annotations + for (auto &ann : inst.cold->annotations) { + if (ann.name == "Barrier") { + inst.barrierMask |= state.getBarrierMask(ann.value); + } + } +} + +void asmInitDeps(AsmFunc &func) { + for (auto &inst : func.asm_) asmInitDep(inst); + // Block init skipped for now +} + +// --- Reorder indices -------------------------------------------------- + +static bool checkAsmBackwardDep(const AsmInst &asm_, + const AsmInst &asmPrev) { + if (asm_.type != AsmType::OP || asmPrev.type != AsmType::OP) + return true; + if (maskAnd(asmPrev.depsTargetMask, asm_.depsSourceMask)) return true; + if (maskAnd(asmPrev.depsSourceMask, asm_.depsTargetMask)) return true; + if (asm_.barrierMask & asmPrev.barrierMask) return true; + return false; +} + +static bool maskGetBit(const RegMask &m, int idx) { + return (m[idx / 64] >> (idx % 64)) & 1ULL; +} + +static void maskSetBit(RegMask &m, int idx) { + m[idx / 64] |= (1ULL << (idx % 64)); +} + +std::vector asmGetReorderIndices(const std::vector &asmList, + int i) { + const AsmInst &asm_ = asmList[i]; + if (asm_.opFlags & OpFlag::OP_FLAG_IS_IMMOVABLE) return {i}; + + int lastWrite[REG_INDEX_SIZE] = {}; + RegMask lastWriteMask = {}; + RegMask lastReadMask = {}; + + int pos = static_cast(asmList.size()); + + // --- Scan forward --- + bool isPastBranch = false; + int f; + for (f = i + 1; f < (int)asmList.size(); ++f) { + const AsmInst &asmNext = asmList[f]; + const AsmInst *prevPrev = + (f >= 2) ? &asmList[f - 2] : nullptr; + + bool isFilledBranch = + (asmNext.opFlags & OpFlag::OP_FLAG_IS_BRANCH) && + !((f + 1 < (int)asmList.size()) && + (asmList[f + 1].opFlags & OpFlag::OP_FLAG_IS_NOP)); + isPastBranch = + prevPrev && (prevPrev->opFlags & OpFlag::OP_FLAG_IS_BRANCH); + + if (isFilledBranch || isPastBranch || + checkAsmBackwardDep(asmNext, asm_)) { + pos = f; + break; + } + + // Track last writes for each register + for (int reg : asmNext.depsTargetIdx) { + lastWrite[reg] = f; + maskSetBit(lastWriteMask, reg); + } + } + + // --- Second pass: collect reads after stop point --- + int fRead = isPastBranch ? f - 2 : f; + for (; fRead < (int)asmList.size(); ++fRead) { + for (int idx = 0; idx < 5; ++idx) + lastReadMask[idx] |= asmList[fRead].depsSourceMask[idx]; + if (asmList[fRead].opFlags & OpFlag::OP_FLAG_IS_BRANCH) break; + } + + // --- Check read-after-write across the gap --- + for (int reg : asm_.depsTargetIdx) { + int lastWritePos = lastWrite[reg]; + if (lastWritePos && maskGetBit(lastReadMask, reg)) { + pos = std::min(lastWritePos, pos); + } + } + + // --- Build forward range --- + std::vector res; + for (int r = i; r <= pos - 1; ++r) res.push_back(r); + + // --- Collect registers not overwritten after us --- + RegMask writeCheckRegsMask = asm_.depsTargetMask; + for (int idx = 0; idx < 5; ++idx) + writeCheckRegsMask[idx] &= ~lastWriteMask[idx]; + + // --- Scan backward --- + for (int b = i - 1; b >= 0; --b) { + const AsmInst &asmPrev = asmList[b]; + bool stop = (b >= 1 && (asmList[b - 1].opFlags & + OpFlag::OP_FLAG_IS_BRANCH)) || + checkAsmBackwardDep(asm_, asmPrev) || + maskAnd(asmPrev.depsTargetMask, writeCheckRegsMask); + if (stop) break; + res.push_back(b); + } + + return res; +} + +void asmScanDeps(AsmFunc &func) { + for (size_t i = 0; i < func.asm_.size(); ++i) { + auto indices = asmGetReorderIndices(func.asm_, i); + if (indices.empty()) continue; + int min = *std::min_element(indices.begin(), indices.end()); + int max = *std::max_element(indices.begin(), indices.end()); + func.asm_[i].debug.reorderLineMin = + (min >= 0 && min < (int)func.asm_.size()) + ? func.asm_[min].debug.lineASM + : 0; + func.asm_[i].debug.reorderLineMax = + (max >= 0 && max < (int)func.asm_.size()) + ? func.asm_[max].debug.lineASM + : 0; + } +} + +} // namespace rspl diff --git a/cpp/src/optimizer/asm_scan_deps.h b/cpp/src/optimizer/asm_scan_deps.h new file mode 100644 index 0000000..85360b6 --- /dev/null +++ b/cpp/src/optimizer/asm_scan_deps.h @@ -0,0 +1,50 @@ +#pragma once + +#include "../asm.h" + +#include +#include +#include +#include +#include + +namespace rspl { + +// 295-bit register mask: 5 × 64-bit words +using RegMask = std::array; + +// Fast register index lookup: register name -> index (0–294), -1 if unknown +int getRegIndex(const std::string &name); +constexpr int REG_INDEX_SIZE = 295; + +// Compact stall index lookup: register name -> index (0–63), -1 if unknown +int getRegStallIndex(const std::string &name); +int getRegStallIndex(const char *name, size_t len); + +// Hidden registers (read/written implicitly by certain ops) +extern const std::unordered_map> + HIDDEN_REGS_READ; +extern const std::unordered_map> + HIDDEN_REGS_WRITE; + +// Lane expansion for vector registers +const std::vector &expandRegister(const std::string ®Name); + +// Get source/target register names for an instruction +std::vector getSourceRegs(const AsmInst &inst); +std::vector getTargetRegs(const AsmInst &inst); + +// Initialize dependency masks/indices for a single instruction +void asmInitDep(AsmInst &inst); + +// Initialize dependency data for all instructions in a function +void asmInitDeps(AsmFunc &func); + +// Get set of indices where instruction at position `i` can be safely reordered +std::vector asmGetReorderIndices(const std::vector &asmList, + int i); + +// Debug: scan and set min/max reorder info for each instruction +void asmScanDeps(AsmFunc &func); + +} // namespace rspl diff --git a/cpp/src/optimizer/eval_cost.cpp b/cpp/src/optimizer/eval_cost.cpp new file mode 100644 index 0000000..6d4b93b --- /dev/null +++ b/cpp/src/optimizer/eval_cost.cpp @@ -0,0 +1,123 @@ +#include "eval_cost.h" +#include "../asm.h" + +#include +#include +#include + +namespace rspl { + +int evalFunctionCost(AsmFunc &func) { + // Filter to only OP instructions (including NOPs). + // Thread-local to reuse allocation across evaluations. + static thread_local std::vector ops; + ops.clear(); + for (auto &inst : func.asm_) { + if (inst.type == AsmType::OP) { + ops.push_back(&inst); + } + } + if (ops.empty()) return 0; + + // regStallExpiry[r] = cycle when register r's stall expires. + // 0 means no active stall (cycle starts at 0, so expiry > 0 means active). + // Replaces the old regCycleMap[64] which stored remaining cycles and + // was fully decremented on every tick (O(64*cycles) → O(active_stalls)). + int regStallExpiry[64] = {}; + int cycle = 0; + int pc = 0; + uint64_t lastLoadPosMask = 0; + int execCount = 0; + + // Branch state: 0=none, 2=branch, 1=delay + int branchStep = 0; + bool didJump = false; + + while (pc < (int)ops.size()) { + // Resolve stalls for to-be-executed instructions. + // Find the max expiry among all source stall registers and advance + // cycle there directly — no need to loop or decrement all 64 entries. + bool resolved; + do { + resolved = true; + for (int i = 0; i < execCount && pc + i < (int)ops.size(); ++i) { + AsmInst *execOp = ops[pc + i]; + execOp->debug.paired = (execCount == 2); + + for (int src : execOp->depsStallSourceIdx) { + if (regStallExpiry[src] > cycle) { + int advance = regStallExpiry[src] - cycle; + cycle += advance; + lastLoadPosMask >>= advance; + resolved = false; + } + } + if ((lastLoadPosMask & 0b001) && + (execOp->opFlags & OpFlag::OP_FLAG_IS_MEM_STALL_STORE)) { + execOp->debug.stall++; + cycle += 1; + lastLoadPosMask >>= 1; + resolved = false; + } + } + } while (!resolved); + + // Execute + for (int i = 0; i < execCount && pc + i < (int)ops.size(); ++i) { + AsmInst *execOp = ops[pc + i]; + if (execOp->opFlags & OpFlag::OP_FLAG_IS_MEM_STALL_LOAD) + lastLoadPosMask |= 0b100; + didJump |= (execOp->opFlags & OpFlag::OP_FLAG_LIKELY_BRANCH); + + branchStep >>= 1; + if (!branchStep && (execOp->opFlags & OpFlag::OP_FLAG_IS_BRANCH)) + branchStep = 2; // BRANCH_STEP_BRANCH + + if (didJump && branchStep == 1) { // BRANCH_STEP_DELAY + cycle += 1; + lastLoadPosMask >>= 1; + didJump = false; + } + + execOp->debug.cycle = cycle; + for (int dst : execOp->depsStallTargetIdx) { + regStallExpiry[dst] = cycle + execOp->stallLatency; + } + } + + pc += execCount; + if (pc >= (int)ops.size()) break; + + AsmInst *op = ops[pc]; + AsmInst *opNext = (pc + 1 < (int)ops.size()) ? ops[pc + 1] : nullptr; + + bool canDualIssue = + opNext && + ((op->opFlags & OpFlag::OP_FLAG_IS_VECTOR) != + (opNext->opFlags & OpFlag::OP_FLAG_IS_VECTOR)) && + !(branchStep == 2) && !(op->opFlags & OpFlag::OP_FLAG_IS_BRANCH) && + (op->depsStallTargetMask0 & opNext->depsStallSourceMask0) == 0 && + (op->depsStallTargetMask1 & opNext->depsStallSourceMask1) == 0 && + (op->depsStallTargetMask0 & opNext->depsStallTargetMask0) == 0 && + (op->depsStallTargetMask1 & opNext->depsStallTargetMask1) == 0; + + // CFC2/CTC2: prevent dual-issue if the current instruction writes + // to any register that the CFC2/CTC2 reads or writes (incl. ctrl regs). + if (canDualIssue && (opNext->opFlags & OpFlag::OP_FLAG_CTC2_CFC2)) { + for (int i = 0; i < 5; ++i) { + if ((op->depsTargetMask[i] & opNext->depsSourceMask[i]) || + (op->depsTargetMask[i] & opNext->depsTargetMask[i])) { + canDualIssue = false; + break; + } + } + } + + execCount = canDualIssue ? 2 : 1; + cycle += 1; + lastLoadPosMask >>= 1; + } + return cycle; +} + +} // namespace rspl \ No newline at end of file diff --git a/cpp/src/optimizer/eval_cost.h b/cpp/src/optimizer/eval_cost.h new file mode 100644 index 0000000..b877d41 --- /dev/null +++ b/cpp/src/optimizer/eval_cost.h @@ -0,0 +1,10 @@ +#pragma once + +namespace rspl { + +struct AsmFunc; + +/// Estimate the cycle cost of a function. Used for optimizer comparison. +int evalFunctionCost(AsmFunc &func); + +} // namespace rspl diff --git a/cpp/src/optimizer/patterns/assertCompare.h b/cpp/src/optimizer/patterns/assertCompare.h new file mode 100644 index 0000000..54e16ed --- /dev/null +++ b/cpp/src/optimizer/patterns/assertCompare.h @@ -0,0 +1,42 @@ +#pragma once +#include "asm.h" +#include "optimizer/asm_optimizer.h" +#include "operations/branch.h" +#include "registers.h" +#include + +namespace rspl { + +inline void assertCompare(AsmFunc &func) { + static const std::string LABEL_ASSERT = "assertion_failed"; + for (size_t i = 0; i + 5 < func.asm_.size(); ++i) { + auto &b = func.asm_[i]; + if (!(b.opFlags & OpFlag::OP_FLAG_IS_BRANCH)) + continue; + if (b.op == Op::J() || b.op == Op::JR() || b.op == Op::JAL()) + continue; + if (!(func.asm_[i + 1].opFlags & OpFlag::OP_FLAG_IS_NOP)) continue; + if (func.asm_[i + 2].op != Op::LUI() || + func.asm_[i + 2].args[0] != reg::Reg::AT) + continue; + if (func.asm_[i + 3].op != Op::J() || + func.asm_[i + 3].args[0] != LABEL_ASSERT) + continue; + if (!(func.asm_[i + 4].opFlags & OpFlag::OP_FLAG_IS_NOP)) continue; + if (func.asm_[i + 5].type != AsmType::LABEL) continue; + if (func.asm_[i + 5].cold->label != b.args.back()) continue; + + b.op = ops::invertBranchOp(b.op); + b.args.back() = LABEL_ASSERT; + b.cold->labelEnd = LABEL_ASSERT; + + AsmInst luiOp = std::move(func.asm_[i + 2]); + func.asm_.insert(func.asm_.begin() + static_cast(i) + 1, + std::move(luiOp)); + func.asm_.erase(func.asm_.begin() + static_cast(i) + 2, + func.asm_.begin() + static_cast(i) + 6); + i += 2; + } +} + +} // namespace rspl diff --git a/cpp/src/optimizer/patterns/branchJump.h b/cpp/src/optimizer/patterns/branchJump.h new file mode 100644 index 0000000..30d3863 --- /dev/null +++ b/cpp/src/optimizer/patterns/branchJump.h @@ -0,0 +1,55 @@ +#pragma once +#include "asm.h" +#include "optimizer/asm_optimizer.h" +#include "operations/branch.h" +#include "registers.h" +#include + +namespace rspl { + +inline void branchJump(AsmFunc &func) { + for (size_t i = 0; i + 4 < func.asm_.size(); ++i) { + auto &b = func.asm_[i]; + if (!(b.opFlags & OpFlag::OP_FLAG_IS_BRANCH) || b.cold->labelEnd.empty()) + continue; + if (b.op == Op::J() || b.op == Op::JAL()) continue; + if (func.asm_[i + 1].op != Op::NOP()) continue; + Opcode jumpOp = func.asm_[i + 2].op; + if (jumpOp != Op::J() && jumpOp != Op::JAL()) continue; + std::string realTarget = func.asm_[i + 2].args[0]; + if (func.asm_[i + 3].op != Op::NOP()) continue; + if (func.asm_[i + 4].type != AsmType::LABEL) continue; + if (func.asm_[i + 4].cold->label != b.cold->labelEnd) continue; + + std::string tempLabel = b.cold->labelEnd; + bool labelUsed = false; + for (const auto &inst : func.asm_) { + if (&inst == &b) continue; + if (inst.cold->labelEnd == tempLabel) { labelUsed = true; break; } + for (const auto &arg : inst.args) + if (arg == tempLabel) { labelUsed = true; break; } + if (labelUsed) break; + } + + b.cold->labelEnd = realTarget; + b.args.back() = realTarget; + Opcode newOp = ops::invertBranchOp(b.op); + b.op = newOp; + + if (jumpOp == Op::JAL()) { + func.asm_.erase(func.asm_.begin() + i + 2, + func.asm_.begin() + i + 4); + func.asm_.insert(func.asm_.begin() + static_cast(i), + asmOp("ori", + {reg::Reg::RA, reg::Reg::ZERO, tempLabel})); + } else if (labelUsed) { + func.asm_.erase(func.asm_.begin() + i + 2, + func.asm_.begin() + i + 4); + } else { + func.asm_.erase(func.asm_.begin() + i + 2, + func.asm_.begin() + i + 5); + } + } +} + +} // namespace rspl diff --git a/cpp/src/optimizer/patterns/commandAlias.h b/cpp/src/optimizer/patterns/commandAlias.h new file mode 100644 index 0000000..48c3a76 --- /dev/null +++ b/cpp/src/optimizer/patterns/commandAlias.h @@ -0,0 +1,23 @@ +#pragma once +#include "optimizer/asm_optimizer.h" + +namespace rspl { + +inline void commandAlias(AsmFunc &func) { + if (func.asm_.size() < 2 || func.type != FuncType::Command) return; + + auto &inst = func.asm_; + Opcode op0 = inst[0].op; + bool isBranch = + (op0 == Op::J() || op0 == Op::JR() || op0 == Op::BEQ() || op0 == Op::BNE()); + if (isBranch && (inst[1].opFlags & OpFlag::OP_FLAG_IS_NOP)) { + if (!inst[0].args.empty()) { + func.nameOverride = inst[0].args[0]; + } + if (inst.size() == 2) { + inst.clear(); + } + } +} + +} // namespace rspl diff --git a/cpp/src/optimizer/patterns/dedupeImm.h b/cpp/src/optimizer/patterns/dedupeImm.h new file mode 100644 index 0000000..95de1a8 --- /dev/null +++ b/cpp/src/optimizer/patterns/dedupeImm.h @@ -0,0 +1,41 @@ +#pragma once +#include "optimizer/asm_optimizer.h" +#include "registers.h" +#include + +namespace rspl { + +inline void dedupeImmediate(AsmFunc &func) { + // Ported from JS dedupeImm.js: track the last value written to $at + // via `ori` and remove redundant `ori $at, $zero, SAME_VALUE`. + std::string lastAT; + std::vector asmNew; + for (auto &asm_ : func.asm_) { + bool keep = true; + if (asm_.type == AsmType::OP) { + if (asm_.opFlags & OpFlag::OP_FLAG_IS_BRANCH) + lastAT.clear(); + + // Check if this instruction writes to $at + if (!asm_.args.empty() && asm_.args[0] == reg::Reg::AT) { + std::string newAT; + // Only handle "ori" — other writes are assumed to set $at + // in unknown ways and reset the cache. + if (asm_.op == Op::ORI() && asm_.args.size() >= 3) { + newAT = asm_.args[2]; + if (!lastAT.empty() && lastAT == newAT) { + keep = false; // redundant — same value already in $at + } + } + lastAT = newAT; + } + } else { + lastAT.clear(); + } + + if (keep) asmNew.push_back(std::move(asm_)); + } + func.asm_ = std::move(asmNew); +} + +} // namespace rspl diff --git a/cpp/src/optimizer/patterns/dedupeJumps.h b/cpp/src/optimizer/patterns/dedupeJumps.h new file mode 100644 index 0000000..f953a41 --- /dev/null +++ b/cpp/src/optimizer/patterns/dedupeJumps.h @@ -0,0 +1,43 @@ +#pragma once +#include "optimizer/asm_optimizer.h" +#include +#include +#include + +namespace rspl { + +inline void dedupeJumps(AsmFunc &func) { + std::vector> labelReplace; + for (size_t i = 0; i < func.asm_.size(); ++i) { + if (func.asm_[i].type == AsmType::LABEL) { + if (i + 1 < func.asm_.size() && func.asm_[i + 1].op == Op::J()) { + labelReplace.push_back( + {func.asm_[i].cold->label, func.asm_[i + 1].args[0]}); + if (i >= 2 && func.asm_[i - 2].op == Op::J() && + func.asm_[i - 1].type == AsmType::OP && + func.asm_[i - 1].op == Op::NOP()) { + func.asm_.erase(func.asm_.begin() + i, + func.asm_.begin() + i + 3); + --i; + } + } + } + } + + for (auto &inst : func.asm_) { + if ((inst.opFlags & OpFlag::OP_FLAG_IS_BRANCH) || inst.op == Op::J() || + inst.op == Op::JAL()) { + if (inst.args.empty()) continue; + std::string &label = inst.args.back(); + for (const auto &[oldL, newL] : labelReplace) { + if (label == oldL) label = newL; + } + if (inst.cold->labelEnd.empty()) continue; + for (const auto &[oldL, newL] : labelReplace) { + if (inst.cold->labelEnd == oldL) inst.cold->labelEnd = newL; + } + } + } +} + +} // namespace rspl diff --git a/cpp/src/optimizer/patterns/dedupeLabels.h b/cpp/src/optimizer/patterns/dedupeLabels.h new file mode 100644 index 0000000..752c679 --- /dev/null +++ b/cpp/src/optimizer/patterns/dedupeLabels.h @@ -0,0 +1,28 @@ +#pragma once +#include "optimizer/asm_optimizer.h" +#include + +namespace rspl { + +inline void dedupeLabels(AsmFunc &func) { + for (size_t i = 0; i + 1 < func.asm_.size(); ++i) { + auto &a = func.asm_[i]; + auto &b = func.asm_[i + 1]; + // Skip __-prefixed labels — these are compiler-internal and should + // never be deduplicated (matching JS dedupeLabels.js:22). + if (a.type != AsmType::LABEL || b.type != AsmType::LABEL) continue; + if (a.cold->label.starts_with("__") || b.cold->label.starts_with("__")) continue; + std::string from = a.cold->label; + std::string to = b.cold->label; + for (auto &inst : func.asm_) { + if (inst.cold->labelEnd == from) inst.cold->labelEnd = to; + for (auto &arg : inst.args) { + if (arg == from) arg = to; + } + } + func.asm_.erase(func.asm_.begin() + i); + --i; + } +} + +} // namespace rspl diff --git a/cpp/src/optimizer/patterns/mergeSequence.h b/cpp/src/optimizer/patterns/mergeSequence.h new file mode 100644 index 0000000..e673cdc --- /dev/null +++ b/cpp/src/optimizer/patterns/mergeSequence.h @@ -0,0 +1,61 @@ +#pragma once +#include "optimizer/asm_optimizer.h" +#include "registers.h" +#include + +namespace rspl { + +inline void mergeSequence(AsmFunc &func) { + for (size_t i = 0; i + 1 < func.asm_.size(); ++i) { + auto &a = func.asm_[i]; + auto &b = func.asm_[i + 1]; + + // Merge: addiu $x, $zero, N -> addu $y, $x, $z into addiu $y, $z, N + if (a.op == Op::ADDIU() && a.args.size() >= 3 && a.args[1] == "$zero" && + b.op == Op::ADDU() && b.args.size() >= 3 && b.args[1] == a.args[0]) { + b.op = Op::ADDIU(); + b.args[1] = b.args[2]; + b.args[2] = a.args[2]; + func.asm_.erase(func.asm_.begin() + i); + --i; + continue; + } + + // Merge consecutive sqrt/reciprocal: first step against VZERO can be + // combined with the start of the next sequence (JS mergeSequence.js:23-41) + if (i + 2 < func.asm_.size() && + ((a.op == Op::VRSQH() || a.op == Op::VRCPH()) && + a.args.size() >= 2 && + a.args[1].starts_with(reg::Reg::VZERO) && + func.asm_[i + 1].op == a.op && + func.asm_[i + 2].op == + (a.op == Op::VRSQH() ? Op::VRSQL() : Op::VRCPL()))) { + func.asm_[i + 1].args[0] = a.args[0]; + func.asm_.erase(func.asm_.begin() + i); + --i; + continue; + } + + // Indirect multiply by zero. + // vxor $reg, $v00, $v00.?? -> ... -> vmudl $reg, $reg, ... + // Replace $reg source in vmudl with $v00 and remove the vxor. + if (a.op == Op::VXOR() && a.args.size() >= 3 && + a.args[1] == reg::Reg::VZERO && + a.args[2].starts_with(reg::Reg::VZERO)) { + std::string targetReg = a.args[0]; + // Search forward for a matching vmudl (within 5 instructions) + for (size_t j = i + 1; j < func.asm_.size() && j <= i + 5; ++j) { + auto &vmudl = func.asm_[j]; + if (vmudl.op == Op::VMUDL() && vmudl.args.size() >= 2 && + vmudl.args[0] == targetReg && vmudl.args[1] == targetReg) { + vmudl.args[1] = reg::Reg::VZERO; + func.asm_.erase(func.asm_.begin() + i); + --i; + break; + } + } + } + } +} + +} // namespace rspl diff --git a/cpp/src/optimizer/patterns/removeDeadCode.h b/cpp/src/optimizer/patterns/removeDeadCode.h new file mode 100644 index 0000000..7bcc188 --- /dev/null +++ b/cpp/src/optimizer/patterns/removeDeadCode.h @@ -0,0 +1,22 @@ +#pragma once +#include "optimizer/asm_optimizer.h" + +namespace rspl { + +inline void removeDeadCode(AsmFunc &func) { + if (func.asm_.empty()) return; + int lastSafeIndex = -1; + for (int i = static_cast(func.asm_.size()) - 1 - 2; i >= 0; --i) { + const auto &inst = func.asm_[i]; + if (inst.op == Op::J() || inst.op == Op::JR()) { + lastSafeIndex = i; + break; + } + if (inst.opFlags & OpFlag::OP_FLAG_IS_NOP) continue; + break; + } + if (lastSafeIndex < 0) return; + func.asm_.erase(func.asm_.begin() + lastSafeIndex + 2, func.asm_.end()); +} + +} // namespace rspl diff --git a/cpp/src/optimizer/patterns/tailCall.h b/cpp/src/optimizer/patterns/tailCall.h new file mode 100644 index 0000000..4645c3b --- /dev/null +++ b/cpp/src/optimizer/patterns/tailCall.h @@ -0,0 +1,44 @@ +#pragma once +#include "optimizer/asm_optimizer.h" + +namespace rspl { + +inline void tailCall(AsmFunc &func) { + // Only applies to commands (JS tailCall.js:24). Match JS behaviour: + // scan for the first jal that can be converted and stop. + if (func.type != FuncType::Command) return; + + for (size_t i = 0; i + 3 < func.asm_.size(); ++i) { + bool matched = false; + + // Pattern 1: jal X; nop; jr $ra; nop → j X; nop + if (func.asm_[i].op == Op::JAL() && func.asm_[i + 1].op == Op::NOP() && + func.asm_[i + 2].op == Op::JR() && + func.asm_[i + 2].args.size() >= 1 && + func.asm_[i + 2].args[0] == "$ra" && + func.asm_[i + 3].op == Op::NOP()) { + matched = true; + } + // Pattern 2: jal X; nop; j RSPQ_Loop; nop → j X; nop + else if (func.asm_[i].op == Op::JAL() && func.asm_[i + 1].op == Op::NOP() && + func.asm_[i + 2].op == Op::J() && + func.asm_[i + 2].args.size() >= 1 && + func.asm_[i + 2].args[0] == "RSPQ_Loop" && + func.asm_[i + 3].op == Op::NOP()) { + matched = true; + } + + if (matched) { + func.asm_[i].op = Op::J(); + func.asm_.erase(func.asm_.begin() + i + 2, + func.asm_.begin() + i + 4); + } + + // If we found a jal at all (whether converted or not), stop. + // JS: "if we encounter a jump, but the above condition is not met, + // we can stop — otherwise it would mean the return register was changed." + if (func.asm_[i].op == Op::J() || func.asm_[i].op == Op::JAL()) return; + } +} + +} // namespace rspl diff --git a/cpp/src/pipeline.cpp b/cpp/src/pipeline.cpp new file mode 100644 index 0000000..263b7b0 --- /dev/null +++ b/cpp/src/pipeline.cpp @@ -0,0 +1,204 @@ +#include "pipeline.h" + +#include "asm_normalize.h" +#include "asm_writer.h" +#include "ast.h" +#include "ast2asm.h" +#include "preproc.h" +#include "optimizer/asm_optimizer.h" +#include "optimizer/asm_scan_deps.h" +#include "optimizer/eval_cost.h" +#include "state.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rspl { + +// --- JS parser subprocess ------------------------------------------- + +static std::string execJsParser(const std::string &rsplPath, + bool skipPreproc) { + const char *scriptPath = std::getenv("RSPL_PARSE_JS"); + std::string cmd; + if (scriptPath) { + cmd = std::string("node ") + scriptPath; + } else { + cmd = "node scripts/parse.js"; + } + cmd += skipPreproc ? " --preprocessed " : " "; + cmd += "\"" + rsplPath + "\""; + cmd += " 2>&1"; // capture stderr too + + FILE *pipe = popen(cmd.c_str(), "r"); + if (!pipe) { + throw std::runtime_error("Error: cannot start JS parser"); + } + std::string result; + char buf[4096]; + while (fgets(buf, sizeof(buf), pipe)) result += buf; + int rc = pclose(pipe); + if (rc != 0) { + throw std::runtime_error("Error: JS parser exited with code " + + std::to_string(rc) + "\n" + result); + } + return result; +} + +// --- runPipeline (CLI path) ----------------------------------------- + +TranspileResult runPipeline(const std::string &astJson, + const TranspileConfig &config) { + auto prog = ast::parseJson(astJson); + + auto functions = ast2asm(prog); + + if (config.optimize) { + for (auto &fn : functions) { + if (fn.asm_.empty()) continue; + asmOptimizePattern(fn); + asmInitDeps(fn); + evalFunctionCost(fn); + } + if (config.reorder) { + for (auto &fn : functions) { + if (fn.asm_.empty()) continue; + asmOptimize(fn, config.optimizeTime, config.optWorkers); + } + printCumulativeStats(); + } else { + for (auto &fn : functions) { + if (fn.asm_.empty()) continue; + fillDelaySlots(fn); + evalFunctionCost(fn); + } + } + } + + WriteConfig wConfig; + wConfig.rspqWrapper = config.rspqWrapper; + wConfig.debugInfo = true; + + auto result = writeASM(prog, functions, wConfig); + + TranspileResult out; + out.asm_ = result.asm_; + out.sizeDMEM = result.sizeDMEM; + out.sizeIMEM = result.sizeIMEM; + return out; +} + +// --- transpileSource (test / library path) -------------------------- + +TranspileResult transpileSource(const std::string &source, + const TranspileConfig &config) { + // Preprocess in C++ to collect defines (ordered by source appearance) + std::unordered_map defines; + std::vector defineOrder; + std::string preprocessed = + preprocFull(source, defines, config.sourceDir, &defineOrder); + + // Populate source lines from the PREPROCESSED source for debug info. + // AST line numbers come from the preprocessed text (includes expanded, + // macros resolved), so the sourceLines must match. + state.sourceLines.clear(); + std::istringstream srcStream(preprocessed); + std::string srcLine; + while (std::getline(srcStream, srcLine)) { + size_t start = srcLine.find_first_not_of(" \t\r"); + size_t end = srcLine.find_last_not_of(" \t\r"); + if (start != std::string::npos) + state.sourceLines.push_back(srcLine.substr(start, end - start + 1)); + else + state.sourceLines.push_back(""); + } + + // Write preprocessed source to temp file + std::string tmpPath = "/tmp/rspl_test_source.rspl"; + { + std::ofstream f(tmpPath); + if (!f) throw std::runtime_error("Cannot write temp file"); + f << preprocessed; + } + + // Call JS parser (skip its own preprocessor since we already did it) + std::string astJson = execJsParser(tmpPath, true); + + // Parse AST + auto prog = ast::parseJson(astJson); + + // Transfer collected defines to the program in source order. + // Filter out defines that were later #undef'd (still in the map). + for (const auto &def : defineOrder) { + if (defines.count(def.name)) + prog.defines.push_back({def.name, def.value}); + } + + // Generate ASM + auto functions = ast2asm(prog); + + // Match JS pipeline: writeASM runs before patterns to advance state.line + // so that optimizer-generated instructions (e.g. branchJump's ori $ra) + // pick up ASM output line numbers instead of stale source line numbers. + if (config.optimize || config.debugInfo) { + WriteConfig wCfg; + wCfg.rspqWrapper = config.rspqWrapper; + wCfg.debugInfo = config.debugInfo; + writeASM(prog, functions, wCfg); + } + + if (config.optimize) { + for (auto &fn : functions) { + if (fn.asm_.empty()) continue; + asmOptimizePattern(fn); + asmInitDeps(fn); + evalFunctionCost(fn); + } + if (config.reorder) { + for (auto &fn : functions) { + if (fn.asm_.empty()) continue; + asmOptimize(fn, config.optimizeTime, config.optWorkers); + } + printCumulativeStats(); + } else { + for (auto &fn : functions) { + if (fn.asm_.empty()) continue; + fillDelaySlots(fn); + evalFunctionCost(fn); + } + } + } else if (config.debugInfo) { + // When debugInfo is on but optimize is off, still run pattern + // optimizations and cycle evaluation so the debug output contains + // meaningful cycle counts. + for (auto &fn : functions) { + if (fn.asm_.empty()) continue; + asmOptimizePattern(fn); + asmInitDeps(fn); + evalFunctionCost(fn); + } + } + + TranspileResult result; + + WriteConfig wConfig; + wConfig.rspqWrapper = config.rspqWrapper; + wConfig.debugInfo = config.debugInfo; + + auto writeResult = writeASM(prog, functions, wConfig); + result.asm_ = writeResult.asm_; + result.sizeDMEM = writeResult.sizeDMEM; + result.sizeIMEM = writeResult.sizeIMEM; + while (!result.asm_.empty() && result.asm_.back() == '\n') + result.asm_.pop_back(); + + return result; +} + +} // namespace rspl diff --git a/cpp/src/pipeline.h b/cpp/src/pipeline.h new file mode 100644 index 0000000..969d493 --- /dev/null +++ b/cpp/src/pipeline.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +namespace rspl { + +struct TranspileConfig { + bool rspqWrapper = true; + bool optimize = false; + bool debugInfo = false; + bool reorder = false; + int optimizeTime = 30000; // ms, default 30s matching CLI + int optWorkers = 0; // 0 = auto (hw threads - 1) + std::string sourceDir = "."; +}; + +struct TranspileResult { + std::string asm_; + std::string warn; + std::string info; + int sizeDMEM = 0; + int sizeIMEM = 0; +}; + +/// Run the full transpile pipeline on a JSON AST from the JS parser. +/// Returns the transpile result. Throws std::runtime_error on errors. +TranspileResult runPipeline(const std::string &astJson, + const TranspileConfig &config = {}); + +/// Transpile an RSPL source string to assembly. +/// Handles the JS parser subprocess internally. +/// Throws std::runtime_error on parse/compile errors. +TranspileResult transpileSource(const std::string &source, + const TranspileConfig &config = {}); + +} // namespace rspl diff --git a/cpp/src/preproc.cpp b/cpp/src/preproc.cpp new file mode 100644 index 0000000..e88e213 --- /dev/null +++ b/cpp/src/preproc.cpp @@ -0,0 +1,162 @@ +#include "preproc.h" + +#include +#include +#include +#include + +namespace rspl { + +std::string stripComments(const std::string &source) { + std::istringstream iss(source); + std::string line; + std::string result; + + while (std::getline(iss, line)) { + // Remove // comments + auto pos = line.find("//"); + if (pos != std::string::npos) { + line = line.substr(0, pos); + } + result += line + "\n"; + } + + // Remove /* */ block comments using a simple state machine + // (std::regex [\s\S] is not portable in C++) + std::string tmp; + bool inBlock = false; + for (size_t i = 0; i < result.size(); ++i) { + if (!inBlock && i + 1 < result.size() && result[i] == '/' && + result[i + 1] == '*') { + inBlock = true; + ++i; // skip * + continue; + } + if (inBlock && i + 1 < result.size() && result[i] == '*' && + result[i + 1] == '/') { + inBlock = false; + ++i; // skip / + continue; + } + if (!inBlock) tmp += result[i]; + else if (result[i] == '\n') tmp += '\n'; // preserve newlines + } + return tmp; +} + +std::string preprocess(const std::string &src, + std::unordered_map &defines, + const std::string &sourceDir, + std::vector *defineOrder) { + std::istringstream iss(src); + std::string line; + std::string result; + bool insideIfdef = false; + bool ignoreLine = false; + int lineNum = 0; + + auto replaceDefines = [&](std::string l) -> std::string { + for (const auto &[name, entry] : defines) { + std::string patStr = + "(\\$\\{" + name + "\\})|(\\b" + name + "\\b)"; + l = std::regex_replace(l, std::regex(patStr), entry.value); + } + return l; + }; + + while (std::getline(iss, line)) { + ++lineNum; + std::string trimmed = line; + size_t firstNonSpace = trimmed.find_first_not_of(" \t"); + if (firstNonSpace != std::string::npos) { + trimmed = trimmed.substr(firstNonSpace); + } else { + trimmed.clear(); + } + std::string newLine; + + if (!ignoreLine && trimmed.starts_with("#define")) { + std::regex defRe("#define\\s+([a-zA-Z0-9_]+)\\s+(.*)"); + std::smatch m; + if (!std::regex_match(trimmed, m, defRe)) { + throw std::runtime_error( + "Line " + std::to_string(lineNum) + + ": Invalid #define statement!"); + } + std::string name = m[1].str(); + std::string value = replaceDefines(m[2].str()); + defines[name] = {name, value}; + if (defineOrder) defineOrder->push_back({name, value}); + } else if (!ignoreLine && trimmed.starts_with("#undef")) { + std::regex undefRe("#undef\\s+([a-zA-Z0-9_]+)"); + std::smatch m; + if (!std::regex_match(trimmed, m, undefRe)) { + throw std::runtime_error( + "Line " + std::to_string(lineNum) + + ": Invalid #undef statement!"); + } + defines.erase(m[1].str()); + } else if (trimmed.starts_with("#ifdef") || + trimmed.starts_with("#ifndef")) { + if (insideIfdef) { + throw std::runtime_error( + "Line " + std::to_string(lineNum) + + ": Nested #ifdef not allowed!"); + } + insideIfdef = true; + std::regex ifdefRe("#ifn?def\\s+([a-zA-Z0-9_]+)"); + std::smatch m; + if (!std::regex_match(trimmed, m, ifdefRe)) { + throw std::runtime_error( + "Line " + std::to_string(lineNum) + + ": Invalid #ifdef statement!"); + } + bool isIfdef = trimmed.starts_with("#ifdef"); + std::string name = m[1].str(); + ignoreLine = isIfdef ? !defines.count(name) : defines.count(name); + } else if (trimmed.starts_with("#else")) { + ignoreLine = insideIfdef && !ignoreLine; + } else if (trimmed.starts_with("#endif")) { + insideIfdef = false; + ignoreLine = false; + } else if (!ignoreLine && trimmed.starts_with("#include")) { + std::regex incRe("#include\\s+\"(.*)\""); + std::smatch m; + if (!std::regex_match(trimmed, m, incRe)) { + throw std::runtime_error( + "Line " + std::to_string(lineNum) + + ": Invalid #include!"); + } + std::string path = m[1].str(); + std::string fullPath = sourceDir + "/" + path; + std::ifstream incFile(fullPath); + if (!incFile) { + throw std::runtime_error( + "Line " + std::to_string(lineNum) + + ": Cannot open include: " + fullPath); + } + std::string incSrc; + { + std::ostringstream ss; + ss << incFile.rdbuf(); + incSrc = ss.str(); + } + result += preprocess(stripComments(incSrc), defines, sourceDir, defineOrder); + } else if (!ignoreLine) { + newLine = replaceDefines(line); + } + + result += newLine + "\n"; + } + + return result; +} + +std::string preprocFull(const std::string &src, + std::unordered_map &defines, + const std::string &sourceDir, + std::vector *defineOrder) { + return preprocess(stripComments(src), defines, sourceDir, defineOrder); +} + +} // namespace rspl diff --git a/cpp/src/preproc.h b/cpp/src/preproc.h new file mode 100644 index 0000000..a7fd4d0 --- /dev/null +++ b/cpp/src/preproc.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + +namespace rspl { + +struct DefineEntry { + std::string name; + std::string value; +}; + +/// Strip C-style comments (// and /* */) from source +std::string stripComments(const std::string &source); + +/// Preprocess with C-style #define, #ifdef, #ifndef, #include, #undef. +/// @param defines map of name->value for predefined defines (modified in-place) +/// @param defineOrder if non-null, records defines in source order +std::string preprocess(const std::string &src, + std::unordered_map &defines, + const std::string &sourceDir = ".", + std::vector *defineOrder = nullptr); + +/// Convenience: stripComments + preprocess +std::string preprocFull(const std::string &src, + std::unordered_map &defines, + const std::string &sourceDir = ".", + std::vector *defineOrder = nullptr); + +} // namespace rspl diff --git a/cpp/src/registers.cpp b/cpp/src/registers.cpp new file mode 100644 index 0000000..a4402be --- /dev/null +++ b/cpp/src/registers.cpp @@ -0,0 +1,71 @@ +#include "registers.h" + +#include + +namespace rspl::reg { + +const std::vector REGS_SCALAR = { + "$zero", "$at", "$v0", "$v1", "$a0", "$a1", "$a2", "$a3", + "$t0", "$t1", "$t2", "$t3", "$t4", "$t5", "$t6", "$t7", + "$s0", "$s1", "$s2", "$s3", "$s4", "$s5", "$s6", "$s7", + "$t8", "$t9", "$k0", "$k1", "$gp", "$sp", "$fp", "$ra", +}; + +const std::vector REGS_VECTOR = { + "$v00", "$v01", "$v02", "$v03", "$v04", "$v05", "$v06", "$v07", + "$v08", "$v09", "$v10", "$v11", "$v12", "$v13", "$v14", "$v15", + "$v16", "$v17", "$v18", "$v19", "$v20", "$v21", "$v22", "$v23", + "$v24", "$v25", "$v26", "$v27", "$v28", "$v29", "$v30", "$v31", +}; + +const std::vector REGS_ALLOC_SCALAR = { + "$t0", "$t1", "$t2", "$t3", "$t4", "$t5", "$t6", "$t7", "$t8", "$t9", + "$k0", "$k1", "$sp", "$fp", + "$s0", "$s1", "$s2", "$s3", "$s4", "$s5", "$s6", "$s7", +}; + +const std::vector REGS_ALLOC_VECTOR = { + "$v01", "$v02", "$v03", "$v04", "$v05", "$v06", "$v07", + "$v08", "$v09", "$v10", "$v11", "$v12", "$v13", "$v14", "$v15", + "$v16", "$v17", "$v18", "$v19", "$v20", "$v21", "$v22", "$v23", + "$v24", "$v25", "$v26", "$v27", "$v28", +}; + +const std::vector REGS_FORBIDDEN = { + Reg::AT, Reg::GP, Reg::VTEMP0, +}; + +bool isVecReg(const std::string ®Name) { + return std::find(REGS_VECTOR.begin(), REGS_VECTOR.end(), regName) != + REGS_VECTOR.end(); +} + +static int indexIn(const std::string &name, + const std::vector &list) { + for (size_t i = 0; i < list.size(); ++i) { + if (list[i] == name) return static_cast(i); + } + return -1; +} + +const std::string *nextReg(const std::string ®Name, int offset) { + int idx = indexIn(regName, REGS_VECTOR); + if (idx >= 0 && idx + offset < (int)REGS_VECTOR.size()) { + return ®S_VECTOR[idx + offset]; + } + idx = indexIn(regName, REGS_SCALAR); + if (idx >= 0 && idx + offset < (int)REGS_SCALAR.size()) { + return ®S_SCALAR[idx + offset]; + } + return nullptr; +} + +const std::string *nextVecReg(const std::string ®Name) { + int idx = indexIn(regName, REGS_VECTOR); + if (idx >= 0 && idx + 1 < (int)REGS_VECTOR.size()) { + return ®S_VECTOR[idx + 1]; + } + return nullptr; +} + +} // namespace rspl::reg diff --git a/cpp/src/registers.h b/cpp/src/registers.h new file mode 100644 index 0000000..ef5a067 --- /dev/null +++ b/cpp/src/registers.h @@ -0,0 +1,120 @@ +#pragma once + +#include +#include + +namespace rspl::reg { + +// --- Register name constants ------------------------------------------ + +struct Reg { + static constexpr const char *AT = "$at"; + static constexpr const char *ZERO = "$zero"; + static constexpr const char *V0 = "$v0"; + static constexpr const char *V1 = "$v1"; + static constexpr const char *A0 = "$a0"; + static constexpr const char *A1 = "$a1"; + static constexpr const char *A2 = "$a2"; + static constexpr const char *A3 = "$a3"; + static constexpr const char *T0 = "$t0"; + static constexpr const char *T1 = "$t1"; + static constexpr const char *T2 = "$t2"; + static constexpr const char *T3 = "$t3"; + static constexpr const char *T4 = "$t4"; + static constexpr const char *T5 = "$t5"; + static constexpr const char *T6 = "$t6"; + static constexpr const char *T7 = "$t7"; + static constexpr const char *T8 = "$t8"; + static constexpr const char *T9 = "$t9"; + static constexpr const char *S0 = "$s0"; + static constexpr const char *S1 = "$s1"; + static constexpr const char *S2 = "$s2"; + static constexpr const char *S3 = "$s3"; + static constexpr const char *S4 = "$s4"; + static constexpr const char *S5 = "$s5"; + static constexpr const char *S6 = "$s6"; + static constexpr const char *S7 = "$s7"; + static constexpr const char *K0 = "$k0"; + static constexpr const char *K1 = "$k1"; + static constexpr const char *GP = "$gp"; + static constexpr const char *SP = "$sp"; + static constexpr const char *FP = "$fp"; + static constexpr const char *RA = "$ra"; + + static constexpr const char *V00 = "$v00"; + static constexpr const char *V01 = "$v01"; + static constexpr const char *V02 = "$v02"; + static constexpr const char *V03 = "$v03"; + static constexpr const char *V04 = "$v04"; + static constexpr const char *V05 = "$v05"; + static constexpr const char *V06 = "$v06"; + static constexpr const char *V07 = "$v07"; + static constexpr const char *V08 = "$v08"; + static constexpr const char *V09 = "$v09"; + static constexpr const char *V10 = "$v10"; + static constexpr const char *V11 = "$v11"; + static constexpr const char *V12 = "$v12"; + static constexpr const char *V13 = "$v13"; + static constexpr const char *V14 = "$v14"; + static constexpr const char *V15 = "$v15"; + static constexpr const char *V16 = "$v16"; + static constexpr const char *V17 = "$v17"; + static constexpr const char *V18 = "$v18"; + static constexpr const char *V19 = "$v19"; + static constexpr const char *V20 = "$v20"; + static constexpr const char *V21 = "$v21"; + static constexpr const char *V22 = "$v22"; + static constexpr const char *V23 = "$v23"; + static constexpr const char *V24 = "$v24"; + static constexpr const char *V25 = "$v25"; + static constexpr const char *V26 = "$v26"; + static constexpr const char *V27 = "$v27"; + static constexpr const char *V28 = "$v28"; + static constexpr const char *V29 = "$v29"; + static constexpr const char *V30 = "$v30"; + static constexpr const char *V31 = "$v31"; + + static constexpr const char *VZERO = "$v00"; + static constexpr const char *VTEMP0 = "$v29"; + static constexpr const char *VSHIFT = "$v30"; + static constexpr const char *VSHIFT8 = "$v31"; +}; + +struct RegCop0 { + static constexpr const char *DMA_BUSY = "COP0_DMA_BUSY"; + static constexpr const char *DP_START = "COP0_DP_START"; + static constexpr const char *DP_END = "COP0_DP_END"; + static constexpr const char *DP_CURRENT = "COP0_DP_CURRENT"; + static constexpr const char *DP_CLOCK = "COP0_DP_CLOCK"; + static constexpr const char *DMA_SPADDR = "COP0_DMA_SPADDR"; + static constexpr const char *DMA_RAMADDR = "COP0_DMA_RAMADDR"; + static constexpr const char *DMA_READ = "COP0_DMA_READ"; + static constexpr const char *DMA_WRITE = "COP0_DMA_WRITE"; + static constexpr const char *SP_STATUS = "COP0_SP_STATUS"; + static constexpr const char *DMA_FULL = "COP0_DMA_FULL"; +}; + +struct RegCop2 { + static constexpr const char *VCO = "$vc0"; + static constexpr const char *VCC = "$vcc"; + static constexpr const char *VCE = "$vce"; + static constexpr const char *ACC_MD = "COP2_ACC_MD"; + static constexpr const char *ACC_HI = "COP2_ACC_HI"; + static constexpr const char *ACC_LO = "COP2_ACC_LO"; +}; + +// --- Register lists --------------------------------------------------- + +extern const std::vector REGS_SCALAR; +extern const std::vector REGS_VECTOR; +extern const std::vector REGS_ALLOC_SCALAR; +extern const std::vector REGS_ALLOC_VECTOR; +extern const std::vector REGS_FORBIDDEN; + +// --- Register helpers ------------------------------------------------- + +bool isVecReg(const std::string ®Name); +const std::string *nextReg(const std::string ®Name, int offset = 1); +const std::string *nextVecReg(const std::string ®Name); + +} // namespace rspl::reg diff --git a/cpp/src/state.cpp b/cpp/src/state.cpp new file mode 100644 index 0000000..ece7927 --- /dev/null +++ b/cpp/src/state.cpp @@ -0,0 +1,498 @@ +#include "state.h" +#include "registers.h" +#include "types.h" + +#include +#include +#include +#include + +namespace rspl { + +// --- State constructor / reset ---------------------------------------- + +static const std::vector LABELS = { + "RSPQ_SCRATCH_MEM", +}; + +State::State() { reset(); } + +void State::reset() { + nextLabelId = 0; + func.clear(); + funcType.clear(); + argSize = 0; + line = 0; + scopeStack.clear(); + memVarMap.clear(); + outWarn.clear(); + outInfo.clear(); + funcMap.clear(); + barrierMaskMap.clear(); + regAllocAllowed = true; + + for (const auto &label : LABELS) { + declareMemVar(label, "u16", 1); + } +} + +// --- Error handling --------------------------------------------------- + +void State::throwError(const std::string &msg, + const std::string &context) const { + std::ostringstream oss; + oss << "Error in " << (func.empty() ? "(???)" : func) << ", line " + << (line == 0 ? "(???)" : std::to_string(line)) << ": " << msg + << "\n -> AST: " << context; + throw std::runtime_error(oss.str()); +} + +void State::logWarning(const std::string &msg, const std::string &context) { + std::ostringstream oss; + oss << "Warning in " << (func.empty() ? "(???)" : func) << ", line " + << (line == 0 ? "(???)" : std::to_string(line)) << ": " << msg + << "\n -> AST: " << context << "\n"; + outWarn += oss.str(); +} + +void State::logInfo(const std::string &msg) { outInfo += msg + '\n'; } + +// --- Function management ---------------------------------------------- + +void State::declareFunction(const std::string &name, + const std::vector &args, + bool isRelative) { + funcMap[name] = {name, args, isRelative}; +} + +void State::enterFunction(const std::string &name, const std::string &type, + int argSize_) { + func = name; + funcType = type; + argSize = argSize_ > 0 ? argSize_ : 0; + line = 0; + scopeStack.clear(); + pushScope(); + + // Declare built-in registers as variables + declareVar("ZERO", "u32", reg::Reg::ZERO, true); + declareVar("VZERO", "vec16", reg::Reg::VZERO, true); + declareVar("VSHIFT", "vec16", reg::Reg::VSHIFT, true); + declareVar("VSHIFT8", "vec16", reg::Reg::VSHIFT8, true); + declareVar("RA", "u32", reg::Reg::RA, false); + declareVar("VTEMP", "vec16", reg::Reg::VTEMP0, false, true); +} + +void State::leaveFunction() { + func.clear(); + funcType.clear(); + line = 0; + scopeStack.clear(); +} + +const FuncDef *State::getFunction(const std::string &name) const { + auto it = funcMap.find(name); + return it != funcMap.end() ? &it->second : nullptr; +} + +// --- Scope management ------------------------------------------------- + +Scope &State::getScope() { return scopeStack.back(); } + +void State::pushScope(const std::string &labelStart, + const std::string &labelEnd) { + Scope child = makeChildScope(); + if (!labelStart.empty() || !labelEnd.empty()) { + child.labelStart = + labelStart.empty() ? child.labelStart : labelStart; + child.labelEnd = labelEnd.empty() ? child.labelEnd : labelEnd; + } + scopeStack.push_back(std::move(child)); +} + +void State::popScope() { scopeStack.pop_back(); } + +Scope State::makeChildScope() const { + if (scopeStack.empty()) { + return Scope{}; + } + const auto &parent = scopeStack.back(); + return Scope{ + .varMap = parent.varMap, + .regVarMap = parent.regVarMap, + .varAliasMap = parent.varAliasMap, + .annotations = parent.annotations, + .labelStart = parent.labelStart, + .labelEnd = parent.labelEnd, + }; +} + +// --- Variable management ---------------------------------------------- + +void State::declareVar(const std::string &name, const std::string &type, + const std::string ®, bool isConst, + bool ignoreReserved) { + if (name.find(':') != std::string::npos) { + throwError("Variable name cannot contain a cast (':')!", {name}); + } + Scope &scope = getScope(); + if (reg.empty()) { + throwError("Cannot declare variable without register!", {name}); + } + if (!ignoreReserved && + std::find(reg::REGS_FORBIDDEN.begin(), reg::REGS_FORBIDDEN.end(), + reg) != reg::REGS_FORBIDDEN.end()) { + throwError("Cannot use reserved register '" + reg + "' for a variable!", + {name}); + } + + if (isVecType(type)) { + if (!reg::isVecReg(reg)) { + throwError("Cannot use scalar register for vector variable!", {name}); + } + } else { + if (reg::isVecReg(reg)) { + throwError("Cannot use vector register for scalar variable!", {name}); + } + } + + // Check for double-allocation + auto checkReg = [&](const std::string &r) { + auto it = scope.regVarMap.find(r); + if (it != scope.regVarMap.end()) { + throwError("Register '" + r + "' already used for variable '" + + it->second + "'!", + {name}); + } + }; + + checkReg(reg); + scope.varMap[name] = VarDef{reg, toTypeClass(type), {}, {}, {}, {}, 0, isConst, 0}; + scope.regVarMap[reg] = name; + + if (isTwoRegType(type)) { + const std::string *nextR = reg::nextReg(reg); + if (!nextR) throwError("No next register for two-reg type!", {name}); + checkReg(*nextR); + scope.regVarMap[*nextR] = name; + } +} + +void State::declareVarAlias(const std::string &aliasName, + const std::string &varName) { + getRequiredVar(varName, "alias"); + Scope &scope = getScope(); + auto it = scope.varAliasMap.find(varName); + const std::string &realName = (it != scope.varAliasMap.end()) + ? it->second + : varName; + scope.varAliasMap[aliasName] = realName; +} + +void State::undefVar(const std::string &varName) { + Scope &scope = getScope(); + + scope.varAliasMap.erase(varName); + std::vector toErase; + for (const auto &[alias, target] : scope.varAliasMap) { + if (target == varName) toErase.push_back(alias); + } + for (const auto &a : toErase) { + scope.varAliasMap.erase(a); + } + + std::string resolved = varName; + auto aliasIt = scope.varAliasMap.find(varName); + if (aliasIt != scope.varAliasMap.end()) { + resolved = aliasIt->second; + } + + auto varIt = scope.varMap.find(resolved); + if (varIt == scope.varMap.end()) { + throwError("Variable " + resolved + " not known!"); + } + + // Free registers + scope.regVarMap.erase(varIt->second.reg); + if (isTwoRegType(varIt->second.type)) { + const std::string *nextR = reg::nextReg(varIt->second.reg); + if (nextR) scope.regVarMap.erase(*nextR); + } + scope.varMap.erase(varIt); +} + +VarDef *State::getVar(const std::string &name) { + Scope &scope = getScope(); + std::string nameNorm = name; + auto colonPos = nameNorm.find(':'); + if (colonPos != std::string::npos) { + nameNorm = nameNorm.substr(0, colonPos); + } + auto aliasIt = scope.varAliasMap.find(nameNorm); + if (aliasIt != scope.varAliasMap.end()) { + nameNorm = aliasIt->second; + } + auto it = scope.varMap.find(nameNorm); + return it != scope.varMap.end() ? &it->second : nullptr; +} + +const VarDef *State::getRequiredVar(const std::string &name, + const std::string &contextName, + const std::string &context) { + VarDef *var = getVar(name); + if (!var) { + // Fallback: check memory variable map + auto memIt = memVarMap.find(name); + if (memIt != memVarMap.end()) { + static thread_local VarDef memVar; + memVar = VarDef{}; + memVar.type = toTypeClass(memIt->second.type); + memVar.name = memIt->second.name; + memVar.reg = "%lo(" + memIt->second.name + ")"; // Use as label ref + return &memVar; + } + throwError(contextName + " Variable " + name + " not known!", context); + } + return var; +} + +VarDef State::getRequiredVarCopy(const std::string &name, + const std::string &contextName, + const std::string &context) { + const VarDef *var = getRequiredVar(name, contextName, context); + VarDef copy = *var; + // Store the original variable name (without cast) for macro arg passing + auto cp = name.find(':'); + copy.name = (cp != std::string::npos) ? name.substr(0, cp) : name; + + // Handle cast suffix + auto colonPos = name.find(':'); + if (colonPos != std::string::npos) { + std::string castStr = name.substr(colonPos + 1); + copy.originalType = copy.type; + copy.castType = toCastType(castStr); + + if (isVecType(copy.type)) { + if (std::find(VEC_CASTS.begin(), VEC_CASTS.end(), castStr) == + VEC_CASTS.end()) { + throwError("Invalid cast type '" + castStr + "' for variable " + + name + ", expected: uint,sint,ufract,sfract!", + context); + } + if (copy.type == TypeClass::Vec32 && + (toCastType(castStr) == CastType::Sfract || toCastType(castStr) == CastType::Ufract)) { + const std::string *nextV = reg::nextVecReg(copy.reg); + if (nextV) copy.reg = *nextV; + } + copy.type = TypeClass::Vec16; + } else { + if (std::find(SCALAR_TYPES.begin(), SCALAR_TYPES.end(), castStr) == + SCALAR_TYPES.end()) { + throwError( + "Invalid cast type '" + castStr + "' for variable " + name + + ", expected: s8,u8,s16,u16,s32,u32", + context); + } + copy.type = toTypeClass(castStr); + } + } + return copy; +} + +const std::string *State::getVarReg(const std::string &name) const { + const Scope &scope = scopeStack.back(); + std::string nameNorm = name; + auto colonPos = nameNorm.find(':'); + if (colonPos != std::string::npos) { + nameNorm = nameNorm.substr(0, colonPos); + } + auto aliasIt = scope.varAliasMap.find(nameNorm); + if (aliasIt != scope.varAliasMap.end()) { + nameNorm = aliasIt->second; + } + auto it = scope.varMap.find(nameNorm); + return it != scope.varMap.end() ? &it->second.reg : nullptr; +} + +bool State::varExists(const std::string &name) const { + const Scope &scope = scopeStack.back(); + std::string nameNorm = name; + auto colonPos = nameNorm.find(':'); + if (colonPos != std::string::npos) { + nameNorm = nameNorm.substr(0, colonPos); + } + auto aliasIt = scope.varAliasMap.find(nameNorm); + if (aliasIt != scope.varAliasMap.end()) { + nameNorm = aliasIt->second; + } + return scope.varMap.count(nameNorm) > 0; +} + +void State::markVarModified(const std::string &name) { + Scope &scope = getScope(); + std::string nameNorm = name; + auto colonPos = nameNorm.find(':'); + if (colonPos != std::string::npos) { + nameNorm = nameNorm.substr(0, colonPos); + } + auto aliasIt = scope.varAliasMap.find(nameNorm); + if (aliasIt != scope.varAliasMap.end()) { + nameNorm = aliasIt->second; + } + auto it = scope.varMap.find(nameNorm); + if (it == scope.varMap.end()) { + throwError("Variable " + name + " not known!"); + } + it->second.modifyCount++; +} + +// --- Memory variables ------------------------------------------------- + +void State::declareMemVar(const std::string &name, const std::string &type, + int arraySize) { + memVarMap[name] = {name, type, arraySize}; +} + +const MemVarDef *State::getRequiredMem(const std::string &name, + const std::string &contextName, + const std::string &context) const { + auto it = memVarMap.find(name); + if (it == memVarMap.end()) { + throwError(contextName + " Memory-Var " + name + " not known!", context); + } + return &it->second; +} + +const MemVarDef *State::getMemVarOrNull(const std::string &name) const { + auto it = memVarMap.find(name); + return it != memVarMap.end() ? &it->second : nullptr; +} + +VarOrMem State::getRequiredVarOrMem(const std::string &name, + const std::string &contextName, + const std::string &context) const { + const Scope &scope = scopeStack.back(); + + // Check memory map first + auto memIt = memVarMap.find(name); + if (memIt != memVarMap.end()) { + return {memIt->second.name, memIt->second.type, "", + memIt->second.arraySize}; + } + + // Check variable scope + std::string nameNorm = name; + auto aliasIt = scope.varAliasMap.find(nameNorm); + if (aliasIt != scope.varAliasMap.end()) { + nameNorm = aliasIt->second; + } + auto varIt = scope.varMap.find(nameNorm); + if (varIt != scope.varMap.end()) { + return {varIt->first, toString(varIt->second.type), varIt->second.reg, 1}; + } + + throwError(contextName + " Variable/Memory " + name + " not known!", + context); + return {}; // unreachable +} + +// --- Register allocation ---------------------------------------------- + +std::string State::allocRegister(const std::string &type) { + if (!regAllocAllowed) { + throwError("Register allocation not allowed in this function!"); + } + + bool reverse = (funcType == "command"); + const auto ®List = isVecType(type) ? reg::REGS_ALLOC_VECTOR + : reg::REGS_ALLOC_SCALAR; + const Scope &scope = getScope(); + bool twoRegs = isTwoRegType(type); + + auto tryAlloc = [&](const std::string ®) -> std::string { + if (scope.regVarMap.count(reg)) return {}; + if (twoRegs) { + const std::string *nextR = reg::nextReg(reg); + if (!nextR || + std::find(regList.begin(), regList.end(), *nextR) == + regList.end() || + scope.regVarMap.count(*nextR)) { + return {}; + } + } + return reg; + }; + + if (reverse) { + for (auto it = regList.rbegin(); it != regList.rend(); ++it) { + std::string found = tryAlloc(*it); + if (!found.empty()) return found; + } + } else { + for (const auto ® : regList) { + std::string found = tryAlloc(reg); + if (!found.empty()) return found; + } + } + + throwError("Out of free registers!"); + return {}; // unreachable +} + +// --- Labels ----------------------------------------------------------- + +std::string State::generateLabel() { + ++nextLabelId; + char buf[64]; + snprintf(buf, sizeof(buf), "LABEL_%s_%04X", + func.c_str(), nextLabelId); + return buf; +} + +// --- Annotations ------------------------------------------------------ + +void State::addAnnotation(const std::string &name, + const std::string &value) { + Scope &scope = getScope(); + scope.annotations.push_back({name, value}); +} + +std::vector State::getAnnotations( + const std::string &name) const { + if (scopeStack.empty()) return {}; + const auto &annos = scopeStack.back().annotations; + if (name.empty()) return annos; + + std::vector result; + for (const auto &a : annos) { + if (a.name == name) result.push_back(a); + } + return result; +} + +void State::clearAnnotations() { + if (!scopeStack.empty()) { + scopeStack.back().annotations.clear(); + } +} + +// --- Barrier masks ---------------------------------------------------- + +uint32_t State::getBarrierMask(const std::string &name) { + auto it = barrierMaskMap.find(name); + if (it != barrierMaskMap.end()) return it->second; + + int len = barrierMaskMap.size(); + if (len >= 32) { + throwError("Too many different barriers, only up to 32 are supported!"); + } + uint32_t mask = (1u << len); + barrierMaskMap[name] = mask; + return mask; +} + +// --- Global instance -------------------------------------------------- + +State state; + +} // namespace rspl diff --git a/cpp/src/state.h b/cpp/src/state.h new file mode 100644 index 0000000..371e833 --- /dev/null +++ b/cpp/src/state.h @@ -0,0 +1,171 @@ +#pragma once + +#include "ast.h" + +#include +#include +#include +#include + +namespace rspl { + +// --- Variable definition ---------------------------------------------- + +struct VarDef { + std::string reg; + TypeClass type = TypeClass::Unknown; + std::string name; // for memory label references + TypeClass originalType = TypeClass::Unknown; // before cast + CastType castType = CastType::None; // e.g. ufract, sfract, s8 + std::string swizzle; // optional swizzle suffix + double value = 0.0; // numeric value when reg is empty + bool isConst = false; + int modifyCount = 0; + + // Backward-compat accessors for code that still uses strings + std::string typeStr() const { return toString(type); } + std::string originalTypeStr() const { return toString(originalType); } + std::string castTypeStr() const { return toString(castType); } +}; + +struct MemVarDef { + std::string name; + std::string type; + int arraySize = 1; +}; + +// Union type for getRequiredVarOrMem — either a register variable or +// a memory label. Check `reg` to see which. +struct VarOrMem { + std::string name; + std::string type; + std::string reg; // empty -> this is a memory variable + int arraySize = 1; // only set for memory variables +}; + +struct FuncDef { + std::string name; + std::vector args; + bool isRelative = false; +}; + +struct AnnotationDef { + std::string name; + std::string value; +}; + +// --- Scope ------------------------------------------------------------ + +struct Scope { + // Variable name -> definition (inherited from parent scope via copy-down) + std::unordered_map varMap; + // Register -> variable name (for collision detection) + std::unordered_map regVarMap; + // Alias -> real variable name (for macro args) + std::unordered_map varAliasMap; + // Scope-local annotations + std::vector annotations; + // Label targets for break/continue + std::string labelStart; + std::string labelEnd; +}; + +// --- State ------------------------------------------------------------ + +class State { +public: + State(); + + // -- Reset ----------------------------------------------------------- + void reset(); + + // -- Error / warning / info ------------------------------------------ + [[noreturn]] void throwError(const std::string &msg, + const std::string &context = "{}") const; + + void logWarning(const std::string &msg, const std::string &context = "{}"); + void logInfo(const std::string &msg); + + // -- Source tracking ------------------------------------------------- + std::vector sourceLines; + std::string func; // current function name + std::string funcType; // "function", "command", "macro" + int argSize = 0; + uint32_t line = 0; + std::string outWarn; + std::string outInfo; + + // -- Function management --------------------------------------------- + void declareFunction(const std::string &name, + const std::vector &args, + bool isRelative = false); + void enterFunction(const std::string &name, const std::string &type, + int argSize); + void leaveFunction(); + const FuncDef *getFunction(const std::string &name) const; + + // -- Scope management ------------------------------------------------ + Scope &getScope(); + void pushScope(const std::string &labelStart = "", + const std::string &labelEnd = ""); + void popScope(); + + // -- Variable management --------------------------------------------- + void declareVar(const std::string &name, const std::string &type, + const std::string ®, bool isConst = false, + bool ignoreReserved = false); + void declareVarAlias(const std::string &aliasName, + const std::string &varName); + void undefVar(const std::string &varName); + VarDef *getVar(const std::string &name); + const VarDef *getRequiredVar(const std::string &name, + const std::string &contextName, + const std::string &context = "{}"); + const std::string *getVarReg(const std::string &name) const; + bool varExists(const std::string &name) const; + void markVarModified(const std::string &name); + VarDef getRequiredVarCopy(const std::string &name, + const std::string &contextName, + const std::string &context = "{}"); + + // -- Memory variables (global state labels) -------------------------- + void declareMemVar(const std::string &name, const std::string &type, + int arraySize); + const MemVarDef *getRequiredMem(const std::string &name, + const std::string &contextName, + const std::string &context = "{}") const; + const MemVarDef *getMemVarOrNull(const std::string &name) const; + VarOrMem getRequiredVarOrMem(const std::string &name, + const std::string &contextName, + const std::string &context = "{}") const; + + // -- Register allocation --------------------------------------------- + std::string allocRegister(const std::string &type); + bool regAllocAllowed = true; + + // -- Labels ---------------------------------------------------------- + std::string generateLabel(); + + // -- Annotations ----------------------------------------------------- + void addAnnotation(const std::string &name, const std::string &value); + std::vector getAnnotations( + const std::string &name = "") const; + void clearAnnotations(); + + // -- Barrier masks --------------------------------------------------- + uint32_t getBarrierMask(const std::string &name); + +private: + int nextLabelId = 0; + std::vector scopeStack; + std::unordered_map memVarMap; + std::unordered_map funcMap; + std::unordered_map barrierMaskMap; + + Scope makeChildScope() const; +}; + +// Global state instance (mirrors JS `export default state`) +extern State state; + +} // namespace rspl diff --git a/cpp/src/swizzle.cpp b/cpp/src/swizzle.cpp new file mode 100644 index 0000000..17d1f60 --- /dev/null +++ b/cpp/src/swizzle.cpp @@ -0,0 +1,33 @@ +#include "swizzle.h" + +namespace rspl { + +const std::unordered_map SWIZZLE_MAP = { + {"", ".v"}, {"xyzwXYZW", ".v"}, {"xxzzXXZZ", ".q0"}, + {"yywwYYWW", ".q1"}, {"xxxxXXXX", ".h0"}, {"yyyyYYYY", ".h1"}, + {"zzzzZZZZ", ".h2"}, {"wwwwWWWW", ".h3"}, {"xxxxxxxx", ".e0"}, + {"yyyyyyyy", ".e1"}, {"zzzzzzzz", ".e2"}, {"wwwwwwww", ".e3"}, + {"XXXXXXXX", ".e4"}, {"YYYYYYYY", ".e5"}, {"ZZZZZZZZ", ".e6"}, + {"WWWWWWWW", ".e7"}, {"x", ".e0"}, {"y", ".e1"}, + {"z", ".e2"}, {"w", ".e3"}, {"X", ".e4"}, + {"Y", ".e5"}, {"Z", ".e6"}, {"W", ".e7"}, +}; + +const std::unordered_map SWIZZLE_SCALAR_IDX = { + {'x', 0}, {'y', 1}, {'z', 2}, {'w', 3}, + {'X', 4}, {'Y', 5}, {'Z', 6}, {'W', 7}, +}; + +const std::unordered_map POW2_SWIZZLE_VAR = { + {0, {"$v00", "x"}}, {1, {"$v30", "W"}}, + {2, {"$v30", "Z"}}, {4, {"$v30", "Y"}}, + {8, {"$v30", "X"}}, {16, {"$v30", "w"}}, + {32, {"$v30", "z"}}, {64, {"$v30", "y"}}, + {128, {"$v30", "x"}}, {256, {"$v31", "W"}}, + {512, {"$v31", "Z"}}, {1024, {"$v31", "Y"}}, + {2048, {"$v31", "X"}}, {4096, {"$v31", "w"}}, + {8192, {"$v31", "z"}}, {16384, {"$v31", "y"}}, + {32768, {"$v31", "x"}}, +}; + +} // namespace rspl diff --git a/cpp/src/swizzle.h b/cpp/src/swizzle.h new file mode 100644 index 0000000..de200b2 --- /dev/null +++ b/cpp/src/swizzle.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +namespace rspl { + +// Swizzle string -> MIPS element suffix (.e0, .h0, etc.) +extern const std::unordered_map SWIZZLE_MAP; + +// Swizzle lane character -> byte offset index +extern const std::unordered_map SWIZZLE_SCALAR_IDX; + +// Power-of-two values -> vector register + swizzle reference +struct Pow2SwizzleRef { + std::string reg; // e.g. "$v30" (VSHIFT) + std::string swizzle; // e.g. "x" +}; +extern const std::unordered_map POW2_SWIZZLE_VAR; + +// Check if swizzle is a single-lane access (.x to .W) +inline bool isScalarSwizzle(const std::string &s) { return s.size() == 1; } + +} // namespace rspl diff --git a/cpp/src/types.cpp b/cpp/src/types.cpp new file mode 100644 index 0000000..fb1d4ba --- /dev/null +++ b/cpp/src/types.cpp @@ -0,0 +1,136 @@ +#include "types.h" + +#include + +namespace rspl { + +// --- string -> enum conversions ----------------------------------------- + +TypeClass toTypeClass(const std::string &s) { + if (s == "vec32") return TypeClass::Vec32; + if (s == "vec16") return TypeClass::Vec16; + if (s == "u32") return TypeClass::U32; + if (s == "u16") return TypeClass::U16; + if (s == "u8") return TypeClass::U8; + if (s == "s32") return TypeClass::S32; + if (s == "s16") return TypeClass::S16; + if (s == "s8") return TypeClass::S8; + if (s == "sfract") return TypeClass::Sfract; + if (s == "ufract") return TypeClass::Ufract; + if (s == "sint") return TypeClass::Sint; + if (s == "uint") return TypeClass::Uint; + if (s.empty()) return TypeClass::Unknown; + throw std::runtime_error("Unknown type class: " + s); +} + +std::string toString(TypeClass tc) { + switch (tc) { + case TypeClass::Vec32: return "vec32"; + case TypeClass::Vec16: return "vec16"; + case TypeClass::U32: return "u32"; + case TypeClass::U16: return "u16"; + case TypeClass::U8: return "u8"; + case TypeClass::S32: return "s32"; + case TypeClass::S16: return "s16"; + case TypeClass::S8: return "s8"; + case TypeClass::Sfract: return "sfract"; + case TypeClass::Ufract: return "ufract"; + case TypeClass::Sint: return "sint"; + case TypeClass::Uint: return "uint"; + case TypeClass::Unknown: return ""; + } + return ""; +} + +CastType toCastType(const std::string &s) { + if (s.empty()) return CastType::None; + if (s == "sfract") return CastType::Sfract; + if (s == "ufract") return CastType::Ufract; + if (s == "sint") return CastType::Sint; + if (s == "uint") return CastType::Uint; + if (s == "s8") return CastType::S8; + if (s == "s16") return CastType::S16; + // Scalar type casts (u32, s16, etc.) are not vector cast types + if (s == "u32" || s == "u16" || s == "u8" || + s == "s32" || s == "s16" || s == "s8") + return CastType::None; + throw std::runtime_error("Unknown cast type: " + s); +} + +std::string toString(CastType ct) { + switch (ct) { + case CastType::None: return ""; + case CastType::Sfract: return "sfract"; + case CastType::Ufract: return "ufract"; + case CastType::Sint: return "sint"; + case CastType::Uint: return "uint"; + case CastType::S8: return "s8"; + case CastType::S16: return "s16"; + } + return ""; +} + +FuncType toFuncType(const std::string &s) { + if (s == "function") return FuncType::Function; + if (s == "command") return FuncType::Command; + if (s == "macro") return FuncType::Macro; + throw std::runtime_error("Unknown function type: " + s); +} + +std::string toString(FuncType ft) { + switch (ft) { + case FuncType::Function: return "function"; + case FuncType::Command: return "command"; + case FuncType::Macro: return "macro"; + } + return "function"; +} + +ArgType toArgType(const std::string &s) { + if (s == "var") return ArgType::Var; + if (s == "num") return ArgType::Num; + if (s == "string") return ArgType::String; + throw std::runtime_error("Unknown arg type: " + s); +} + +std::string toString(ArgType at) { + switch (at) { + case ArgType::Var: return "var"; + case ArgType::Num: return "num"; + case ArgType::String: return "string"; + } + return "var"; +} + +const std::unordered_map TYPE_SIZE = { + {"s8", 1}, {"u8", 1}, {"s16", 2}, {"u16", 2}, + {"s32", 4}, {"u32", 4}, {"vec16", 16}, {"vec32", 32}, +}; + +const std::unordered_map TYPE_ALIGNMENT = { + {"s8", 0}, {"u8", 0}, {"s16", 1}, {"u16", 1}, + {"s32", 2}, {"u32", 2}, {"vec16", 4}, {"vec32", 4}, +}; + +const std::unordered_map TYPE_REG_COUNT = { + {"s8", 1}, {"u8", 1}, {"s16", 1}, {"u16", 1}, + {"s32", 1}, {"u32", 1}, {"vec16", 1}, {"vec32", 2}, +}; + +const std::unordered_map TYPE_ASM_DEF = { + {"s8", {"byte", 1}}, + {"u8", {"byte", 1}}, + {"s16", {"half", 1}}, + {"u16", {"half", 1}}, + {"s32", {"word", 1}}, + {"u32", {"word", 1}}, + {"vec16", {"half", 8}}, + {"vec32", {"half", 16}}, +}; + +const std::vector SCALAR_TYPES = { + "s8", "u8", "s16", "u16", "s32", "u32"}; + +const std::vector VEC_CASTS = {"uint", "sint", "ufract", "sfract"}; + +} // namespace rspl diff --git a/cpp/src/types.h b/cpp/src/types.h new file mode 100644 index 0000000..09e6237 --- /dev/null +++ b/cpp/src/types.h @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include +#include + +namespace rspl { + +// --- Type enums (replaces string-based type checks in hot paths) ---------- + +enum class TypeClass : uint8_t { + Unknown = 0, // must be 0 so VarDef{} zero-inits to Unknown, not Vec32 + Vec32, + Vec16, + U32, + U16, + U8, + S32, + S16, + S8, + Sfract, + Ufract, + Sint, + Uint, +}; + +enum class CastType : uint8_t { + None, + Sfract, + Ufract, + Sint, + Uint, + S8, + S16 +}; + +enum class FuncType : uint8_t { Function, Command, Macro }; + +enum class ArgType : uint8_t { Var, Num, String }; + +// --- Conversion: string <-> enum ---------------------------------------- + +TypeClass toTypeClass(const std::string &s); +std::string toString(TypeClass tc); +CastType toCastType(const std::string &s); +std::string toString(CastType ct); +FuncType toFuncType(const std::string &s); +std::string toString(FuncType ft); +ArgType toArgType(const std::string &s); +std::string toString(ArgType at); + +// --- Type checks (now on enums, single-cycle) --------------------------- + +inline bool isVecType(TypeClass tc) { + return tc == TypeClass::Vec32 || tc == TypeClass::Vec16; +} +inline bool isTwoRegType(TypeClass tc) { return tc == TypeClass::Vec32; } +inline bool isSigned(TypeClass tc) { + return tc == TypeClass::S32 || tc == TypeClass::S16 || tc == TypeClass::S8 || + tc == TypeClass::Sint || tc == TypeClass::Sfract; +} + +// --- Backward-compat helpers for places that still use strings ----------- + +// These overloads exist so call sites that currently pass strings +// work during the transition. +inline bool isVecType(const std::string &s) { return isVecType(toTypeClass(s)); } +inline bool isTwoRegType(const std::string &s) { return s == "vec32"; } + +// --- Type size / alignment / reg count ---------------------------------- + +struct TypeAsmDef { + std::string type; // "byte", "half", "word" + int count; +}; + +extern const std::unordered_map TYPE_SIZE; +extern const std::unordered_map TYPE_ALIGNMENT; +extern const std::unordered_map TYPE_REG_COUNT; +extern const std::unordered_map TYPE_ASM_DEF; + +// --- Type lists --------------------------------------------------------- + +extern const std::vector SCALAR_TYPES; +extern const std::vector VEC_CASTS; + +// --- Misc helpers ------------------------------------------------------- + +inline std::string toHex(int64_t val, int pad = 2) { + char buf[32]; + snprintf(buf, sizeof(buf), "0x%0*llX", pad, + static_cast(val)); + return buf; +} + +inline bool u32InS16Range(uint32_t valueU32) { + return valueU32 <= 0x7FFF || valueU32 >= 0xFFFF8000; +} + +inline bool u32InU16Range(uint32_t valueU32) { return valueU32 <= 0xFFFF; } + +inline uint32_t f32ToFP32(float valueF32) { + return static_cast(static_cast(valueF32 * (1 << 16))); +} + +} // namespace rspl \ No newline at end of file diff --git a/cpp/tests/diff_util.h b/cpp/tests/diff_util.h new file mode 100644 index 0000000..980da48 --- /dev/null +++ b/cpp/tests/diff_util.h @@ -0,0 +1,71 @@ +#pragma once +#include +#include +#include + +// Print up to `maxDiffs` line-by-line differences with `ctx` surrounding +// context lines. Shown as "- expected" / "+ actual" pairs. +inline std::string diffLines(const std::string &expected, + const std::string &actual, int ctx = 2, + int maxDiffs = 25) { + std::vector e, a; + auto split = [](const std::string &s, std::vector &v) { + std::istringstream ss(s); + std::string l; + while (std::getline(ss, l)) v.push_back(l); + }; + split(expected, e); + split(actual, a); + + // Simple equal-prefix walk, then show context around mismatches. + size_t i = 0; + int shown = 0; + std::ostringstream out; + while (i < e.size() || i < a.size()) { + bool eq = + i < e.size() && i < a.size() && e[i] == a[i]; + if (eq) { ++i; continue; } + if (shown >= maxDiffs) break; + + // Print context before + size_t cs = (i > (size_t)ctx) ? i - ctx : 0; + if (shown > 0) out << "--\n"; + out << "Line " << (i + 1) << ":\n"; + for (size_t k = cs; k < i && k < e.size(); ++k) + out << " " << e[k] << "\n"; + // Print diff + if (i < e.size()) out << "- " << e[i] << "\n"; + if (i < a.size()) out << "+ " << a[i] << "\n"; + // Print context after + size_t ce = std::min(i + ctx + 1, + std::max(e.size(), a.size())); + for (size_t k = i + 1; k < ce; ++k) { + if (k < e.size() && k < a.size() && e[k] == a[k]) + out << " " << e[k] << "\n"; + else + break; + } + ++i; ++shown; + } + if (shown >= maxDiffs) + out << "... (" << maxDiffs << " diffs shown, more omitted)\n"; + + // If no diffs found but lengths differ, show tail + if (shown == 0 && e.size() != a.size()) { + out << "Length mismatch: expected " << e.size() << " lines, got " + << a.size() << "\n"; + size_t start = std::min(e.size(), a.size()); + for (size_t k = start; k < std::min(start + 5, e.size()); ++k) + out << "- " << e[k] << "\n"; + for (size_t k = start; k < std::min(start + 5, a.size()); ++k) + out << "+ " << a[k] << "\n"; + } + return out.str(); +} + +#define REQUIRE_ASM_EQ(expected, actual) \ + do { \ + if ((expected) != (actual)) { \ + FAIL_CHECK("ASM mismatch:\n" << diffLines(expected, actual)); \ + } \ + } while (0) diff --git a/cpp/tests/test_annotations.cpp b/cpp/tests/test_annotations.cpp new file mode 100644 index 0000000..b4db5bc --- /dev/null +++ b/cpp/tests/test_annotations.cpp @@ -0,0 +1,78 @@ +#include +#include "pipeline.h" + +TEST_CASE("Annotations - Align (function)", "[annotations]") { + auto result = rspl::transpileSource( + R"( +@Align(8) +function test() +{ + exit; +})", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(.align 3 +test: + j RSPQ_Loop + nop + jr $ra + nop)"); +} + +TEST_CASE("Annotations - Relative (function)", "[annotations]") { + auto result = rspl::transpileSource( + R"( +@Relative +function target_rel() {} +function target_abs() {} +function caller() { + target_rel(); + target_abs(); +} +)", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(target_rel: + jr $ra + nop +target_abs: + jr $ra + nop +caller: + bgezal $zero, target_rel + nop + jal target_abs + nop + jr $ra + nop)"); +} + +TEST_CASE("Annotations - Relative (caller)", "[annotations]") { + auto result = rspl::transpileSource( + R"( +function target_rel() {} +function target_abs() {} +function caller() { + @Relative target_rel(); + target_abs(); +} +)", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(target_rel: + jr $ra + nop +target_abs: + jr $ra + nop +caller: + bgezal $zero, target_rel + nop + jal target_abs + nop + jr $ra + nop)"); +} diff --git a/cpp/tests/test_ast.cpp b/cpp/tests/test_ast.cpp new file mode 100644 index 0000000..f9cbed0 --- /dev/null +++ b/cpp/tests/test_ast.cpp @@ -0,0 +1,88 @@ +#include +#include "ast.h" + +using namespace rspl::ast; + +TEST_CASE("parse minimal function", "[ast]") { + auto prog = parseJson(R"({ + "includes": [], + "states": [], + "functions": [{ + "annotations": [], + "type": "function", + "resultType": null, + "name": "test", + "args": [], + "body": { + "type": "scopedBlock", + "statements": [{ + "type": "varDeclMulti", + "varType": "u32", + "reg": "$t0", + "varNames": ["a","b","c"], + "isConst": false, + "line": 2 + }], + "line": 1 + } + }], + "postIncludes": [] + })"); + + REQUIRE(prog.functions.size() == 1); + REQUIRE(prog.functions[0].name == "test"); + REQUIRE(toString(prog.functions[0].type) == "function"); + REQUIRE(prog.functions[0].body != nullptr); + REQUIRE(prog.functions[0].body->statements.size() == 1); +} + +TEST_CASE("parse calcNum with plain number", "[ast]") { + auto prog = parseJson(R"({ + "includes": [], "states": [], + "functions": [{ + "annotations": [], "type": "function", "resultType": null, "name": "f", "args": [], + "body": { + "type": "scopedBlock", "statements": [{ + "type": "varDeclAssign", + "varType": "u32", "varName": "x", + "calc": { "type": "calcNum", "right": 42 }, + "isConst": false, "line": 1 + }], "line": 1 + } + }], + "postIncludes": [] + })"); + REQUIRE(prog.functions.size() == 1); +} + +TEST_CASE("parse compare in if statement", "[ast]") { + auto prog = parseJson(R"({ + "includes": [], "states": [], + "functions": [{ + "annotations": [], "type": "function", "resultType": null, "name": "f", "args": [], + "body": { + "type": "scopedBlock", "statements": [{ + "type": "if", + "compare": { + "left": {"type":"var","value":"a"}, + "op": ">", + "right": {"type":"num","value":10} + }, + "blockIf": { + "type": "scopedBlock", "statements": [{ + "type": "varAssignCalc", + "varName": "b", + "assignType": "=", + "calc": { "type": "calcNum", "right": 5 }, + "line": 2 + }], "line": 2 + }, + "blockElse": null, + "line": 1 + }], "line": 1 + } + }], + "postIncludes": [] + })"); + REQUIRE(prog.functions.size() == 1); +} diff --git a/cpp/tests/test_branchConst.cpp b/cpp/tests/test_branchConst.cpp new file mode 100644 index 0000000..8e682da --- /dev/null +++ b/cpp/tests/test_branchConst.cpp @@ -0,0 +1,332 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Branch (Var vs. Const) - Equal - U32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a == 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + addiu $at, $zero, 42 + bne $v0, $at, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Equal - U32 (big number)", + "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a == 0x112233) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + lui $at, 0x11 + ori $at, $at, 0x2233 + bne $v0, $at, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Not-Equal - U32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a != 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + addiu $at, $zero, 42 + beq $v0, $at, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Greater - U32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a > 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + sltiu $at, $v0, 43 + bne $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Greater - U32 (big number)", + "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a > 0xFFFEFFFF) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + lui $at, 0xFFFF + sltu $at, $v0, $at + bne $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Less - U32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a < 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + sltiu $at, $v0, 42 + beq $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Greater-Than - U32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a >= 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + sltiu $at, $v0, 42 + bne $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Less-Than - U32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a <= 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + sltiu $at, $v0, 43 + beq $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +// Signed + +TEST_CASE("Branch (Var vs. Const) - Equal - S32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a == 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + addiu $at, $zero, 42 + bne $v0, $at, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Not-Equal - S32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a != 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + addiu $at, $zero, 42 + beq $v0, $at, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Greater - S32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a > 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + slti $at, $v0, 43 + bne $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Less - S32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a < 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + slti $at, $v0, 42 + beq $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Greater-Than - S32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a >= 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + slti $at, $v0, 42 + bne $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Const) - Less-Than - S32", "[branchConst]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a <= 42) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + slti $at, $v0, 43 + beq $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} diff --git a/cpp/tests/test_branchVar.cpp b/cpp/tests/test_branchVar.cpp new file mode 100644 index 0000000..e6d4bf8 --- /dev/null +++ b/cpp/tests/test_branchVar.cpp @@ -0,0 +1,278 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Branch (Var vs. Var) - Equal - U32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a,b; + if(a == b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + bne $v0, $v1, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Var) - Not-Equal - U32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a,b; + if(a != b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + beq $v0, $v1, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Var) - Greater - U32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a,b; + if(a > b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + sltu $at, $v1, $v0 + beq $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Var) - Less - U32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a,b; + if(a < b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + sltu $at, $v0, $v1 + beq $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Var) - Greater-Than - U32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a,b; + if(a >= b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + sltu $at, $v0, $v1 + bne $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Var) - Less-Than - U32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a,b; + if(a <= b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + sltu $at, $v1, $v0 + bne $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +// Signed + +TEST_CASE("Branch (Var vs. Var) - Equal - S32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a,b; + if(a == b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + bne $v0, $v1, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Var) - Not-Equal - S32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a,b; + if(a != b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + beq $v0, $v1, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Var) - Greater - S32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a,b; + if(a > b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + slt $at, $v1, $v0 + beq $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Var) - Less - S32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a,b; + if(a < b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + slt $at, $v0, $v1 + beq $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Var) - Greater-Than - S32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a,b; + if(a >= b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + slt $at, $v0, $v1 + bne $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. Var) - Less-Than - S32", "[branchVar]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a,b; + if(a <= b) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + slt $at, $v1, $v0 + bne $at, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} diff --git a/cpp/tests/test_branchZero.cpp b/cpp/tests/test_branchZero.cpp new file mode 100644 index 0000000..9d1bee9 --- /dev/null +++ b/cpp/tests/test_branchZero.cpp @@ -0,0 +1,270 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Branch (Var vs. 0) - Equal - U32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a == 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + bne $v0, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. 0) - Not-Equal - U32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a != 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + beq $v0, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. 0) - Greater - U32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a > 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + blez $v0, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. 0) - Less - U32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a < 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + bgez $v0, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. 0) - Greater-Equal - U32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a >= 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + bltz $v0, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. 0) - Less-Equal - U32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + u32<$v0> a; + if(a <= 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + bgtz $v0, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +// Signed + +TEST_CASE("Branch (Var vs. 0) - Equal - S32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a == 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + bne $v0, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. 0) - Not-Equal - S32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a != 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + beq $v0, $zero, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. 0) - Greater - S32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a > 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + blez $v0, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. 0) - Less - S32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a < 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + bgez $v0, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. 0) - Greater-Equal - S32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a >= 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + bltz $v0, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} + +TEST_CASE("Branch (Var vs. 0) - Less-Equal - S32", "[branchZero]") { + auto result = rspl::transpileSource( + R"(function test_if() { + s32<$v0> a; + if(a <= 0) { a += 1111; } else { a += 2222; } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_if: + bgtz $v0, LABEL_test_if_0001 + nop + addiu $v0, $v0, 1111 + beq $zero, $zero, LABEL_test_if_0002 + nop + LABEL_test_if_0001: + addiu $v0, $v0, 2222 + LABEL_test_if_0002: + jr $ra + nop)"); +} diff --git a/cpp/tests/test_builtins.cpp b/cpp/tests/test_builtins.cpp new file mode 100644 index 0000000..170e735 --- /dev/null +++ b/cpp/tests/test_builtins.cpp @@ -0,0 +1,334 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Builtins - swap() - scalar", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + u32 v0, v1; + swap(v0, v1); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + xor $t0, $t0, $t1 + xor $t1, $t0, $t1 + xor $t0, $t0, $t1 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - swap() - vec16", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16 v0, v1; + swap(v0, v1); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vxor $v01, $v01, $v02 + vxor $v02, $v01, $v02 + vxor $v01, $v01, $v02 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - swap() - vec32", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32 v0, v1; + swap(v0, v1); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vxor $v01, $v01, $v03 + vxor $v03, $v01, $v03 + vxor $v01, $v01, $v03 + vxor $v02, $v02, $v04 + vxor $v04, $v02, $v04 + vxor $v02, $v02, $v04 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - asm_include()", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + u32 a = 4; + asm_include("./test.inc"); + a = 5; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $t0, $zero, 4 + #define zero $0 + #define v0 $2 + #define v1 $3 + #define a0 $4 + #define a1 $5 + #define a2 $6 + #define a3 $7 + #define t0 $8 + #define t1 $9 + #define t2 $10 + #define t3 $11 + #define t4 $12 + #define t5 $13 + #define t6 $14 + #define t7 $15 + #define s0 $16 + #define s1 $17 + #define s2 $18 + #define s3 $19 + #define s4 $20 + #define s5 $21 + #define s6 $22 + #define s7 $23 + #define t8 $24 + #define t9 $25 + #define k0 $26 + #define k1 $27 + #define gp $28 + #define sp $29 + #define fp $30 + #define ra $31 + .set at + .set macro + #include "./test.inc" + .set noreorder + .set noat + .set nomacro + #undef zero + #undef at + #undef v0 + #undef v1 + #undef a0 + #undef a1 + #undef a2 + #undef a3 + #undef t0 + #undef t1 + #undef t2 + #undef t3 + #undef t4 + #undef t5 + #undef t6 + #undef t7 + #undef s0 + #undef s1 + #undef s2 + #undef s3 + #undef s4 + #undef s5 + #undef s6 + #undef s7 + #undef t8 + #undef t9 + #undef k0 + #undef k1 + #undef gp + #undef sp + #undef fp + #undef ra + addiu $t0, $zero, 5 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - transpose() - 8x8 in-place", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v08> v0; + u16 buff; + v0 = transpose(v0, buff, 8, 8); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + stv $v08, 2, 16, $t0 + stv $v08, 4, 32, $t0 + stv $v08, 6, 48, $t0 + stv $v08, 8, 64, $t0 + stv $v08, 10, 80, $t0 + stv $v08, 12, 96, $t0 + stv $v08, 14, 112, $t0 + ltv $v08, 14, 16, $t0 + ltv $v08, 12, 32, $t0 + ltv $v08, 10, 48, $t0 + ltv $v08, 8, 64, $t0 + ltv $v08, 6, 80, $t0 + ltv $v08, 4, 96, $t0 + ltv $v08, 2, 112, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - transpose() - 8x8 src-target", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v08> v0; + vec16<$v16> v1; + u16 buff; + v1 = transpose(v0, buff, 8, 8); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + stv $v08, 0, 0, $t0 + stv $v08, 2, 16, $t0 + stv $v08, 4, 32, $t0 + stv $v08, 6, 48, $t0 + stv $v08, 8, 64, $t0 + stv $v08, 10, 80, $t0 + stv $v08, 12, 96, $t0 + stv $v08, 14, 112, $t0 + ltv $v16, 14, 16, $t0 + ltv $v16, 12, 32, $t0 + ltv $v16, 10, 48, $t0 + ltv $v16, 8, 64, $t0 + ltv $v16, 6, 80, $t0 + ltv $v16, 4, 96, $t0 + ltv $v16, 2, 112, $t0 + ltv $v16, 0, 0, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - transpose() - 4x4 in-place", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v08> v0; + u16 buff; + v0 = transpose(v0, buff, 4, 4); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + stv $v08, 2, 16, $t0 + stv $v08, 4, 32, $t0 + stv $v08, 6, 48, $t0 + stv $v08, 10, 80, $t0 + stv $v08, 12, 96, $t0 + stv $v08, 14, 112, $t0 + ltv $v08, 14, 16, $t0 + ltv $v08, 12, 32, $t0 + ltv $v08, 10, 48, $t0 + ltv $v08, 6, 80, $t0 + ltv $v08, 4, 96, $t0 + ltv $v08, 2, 112, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - transpose() - 4x4 src-target", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v08> v0; + vec16<$v16> v1; + u16 buff; + v1 = transpose(v0, buff, 4, 4); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + stv $v08, 0, 0, $t0 + stv $v08, 2, 16, $t0 + stv $v08, 4, 32, $t0 + stv $v08, 6, 48, $t0 + stv $v08, 10, 80, $t0 + stv $v08, 12, 96, $t0 + stv $v08, 14, 112, $t0 + ltv $v16, 14, 16, $t0 + ltv $v16, 12, 32, $t0 + ltv $v16, 10, 48, $t0 + ltv $v16, 6, 80, $t0 + ltv $v16, 4, 96, $t0 + ltv $v16, 2, 112, $t0 + ltv $v16, 0, 0, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - abs() - 16bit", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v08> v0; + vec16<$v16> v1; + TEST_0: + v0 = abs(v1); + TEST_1: + v1 = abs(v1); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + TEST_0: + vabs $v08, $v16, $v16 + TEST_1: + vabs $v16, $v16, $v16 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - clip() - vec32", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> a; + u32<$t0> res; + res = clip(a, a); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vch $v29, $v01, $v01 + vcl $v29, $v02, $v02 + cfc2 $t0, $vcc + jr $ra + nop)"); +} + +TEST_CASE("Builtins - clip() - vec16", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a; + u32<$t0> res; + res = clip(a, a); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vch $v29, $v01, $v01 + cfc2 $t0, $vcc + jr $ra + nop)"); +} + +TEST_CASE("Builtins - get_ticks()", "[builtins]") { + auto result = rspl::transpileSource( + R"(function test() { + u32 a; + a = get_ticks(); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + mfc0 $t0, COP0_DP_CLOCK + jr $ra + nop)"); +} diff --git a/cpp/tests/test_builtinsAll.cpp b/cpp/tests/test_builtinsAll.cpp new file mode 100644 index 0000000..55a2192 --- /dev/null +++ b/cpp/tests/test_builtinsAll.cpp @@ -0,0 +1,576 @@ +#include +#include +#include "pipeline.h" + +#include + +#define TRANSPILE(src) \ + rspl::transpileSource(src, {.rspqWrapper = false}) + +#define REQUIRE_NO_WARN(r) REQUIRE(r.warn.empty()) +#define REQUIRE_ASM(r, expected) REQUIRE(r.asm_ == expected) +#define REQUIRE_THROWS_MSG(src, msg) \ + REQUIRE_THROWS_WITH(rspl::transpileSource(src, {.rspqWrapper = false}), \ + Catch::Matchers::ContainsSubstring(msg)) + +// ========================================================================== +// MFC0 Reads — parameterized via macro +// ========================================================================== + +#define MFC0_READ_TESTS(name, cop0Reg) \ + TEST_CASE("Builtins - " name "() - basic", "[mfc0_reads]") { \ + auto r = TRANSPILE("function test() {\n u32<$t0> a = " name \ + "();\n }"); \ + REQUIRE_NO_WARN(r); \ + REQUIRE_ASM(r, "test:\n mfc0 $t0, " cop0Reg "\n jr $ra\n nop"); \ + } \ + TEST_CASE("Builtins - " name "() - fails with no left side", \ + "[mfc0_reads]") { \ + REQUIRE_THROWS_MSG("function test() {\n " name \ + "();\n }", \ + "must have a left side"); \ + } \ + TEST_CASE("Builtins - " name "() - fails with arguments", "[mfc0_reads]") { \ + REQUIRE_THROWS_MSG( \ + "function test() {\n u32<$t0> a = " name \ + "(42);\n }", \ + "requires no arguments"); \ + } \ + TEST_CASE("Builtins - " name "() - fails with vector left side", \ + "[mfc0_reads]") { \ + REQUIRE_THROWS_MSG( \ + "function test() {\n vec16<$v01> a = " name \ + "();\n }", \ + "scalar variable"); \ + } + +MFC0_READ_TESTS("get_dma_busy", "COP0_DMA_BUSY") +MFC0_READ_TESTS("get_rdp_start", "COP0_DP_START") +MFC0_READ_TESTS("get_rdp_end", "COP0_DP_END") +MFC0_READ_TESTS("get_rdp_current", "COP0_DP_CURRENT") + +// ========================================================================== +// MTC0 Writes — parameterized via macro +// ========================================================================== + +#define MTC0_WRITE_TESTS(name, cop0Reg) \ + TEST_CASE("Builtins - " name "() - basic - scalar variable", \ + "[mtc0_writes]") { \ + auto r = TRANSPILE("function test() {\n u32<$t0> a;\n " name \ + "(a);\n }"); \ + REQUIRE_NO_WARN(r); \ + REQUIRE_ASM(r, "test:\n mtc0 $t0, " cop0Reg "\n jr $ra\n nop"); \ + } \ + TEST_CASE("Builtins - " name "() - basic - literal", "[mtc0_writes]") { \ + auto r = TRANSPILE("function test() {\n " name "(42);\n }"); \ + REQUIRE_NO_WARN(r); \ + REQUIRE_ASM(r, "test:\n addiu $at, $zero, 42\n mtc0 $at, " cop0Reg \ + "\n jr $ra\n nop"); \ + } \ + TEST_CASE("Builtins - " name "() - fails with left side", \ + "[mtc0_writes]") { \ + REQUIRE_THROWS_MSG( \ + "function test() {\n u32<$t0> a = " name \ + "(42);\n }", \ + "must not have a left side"); \ + } \ + TEST_CASE("Builtins - " name "() - fails with no argument", \ + "[mtc0_writes]") { \ + REQUIRE_THROWS_MSG("function test() {\n " name "();\n }", \ + "requires 1 scalar"); \ + } \ + TEST_CASE("Builtins - " name "() - fails with vector argument", \ + "[mtc0_writes]") { \ + REQUIRE_THROWS_MSG( \ + "function test() {\n vec16<$v01> a;\n " name \ + "(a);\n }", \ + "scalar argument"); \ + } + +MTC0_WRITE_TESTS("set_rdp_start", "COP0_DP_START") +MTC0_WRITE_TESTS("set_rdp_end", "COP0_DP_END") +MTC0_WRITE_TESTS("set_rdp_current", "COP0_DP_CURRENT") +MTC0_WRITE_TESTS("set_dma_addr_rsp", "COP0_DMA_SPADDR") +MTC0_WRITE_TESTS("set_dma_addr_rdram", "COP0_DMA_RAMADDR") +MTC0_WRITE_TESTS("set_dma_write", "COP0_DMA_WRITE") +MTC0_WRITE_TESTS("set_dma_read", "COP0_DMA_READ") + +// ========================================================================== +// clear_vcc +// ========================================================================== + +TEST_CASE("Builtins - clear_vcc() - basic", "[vcc]") { + auto r = TRANSPILE(R"(function test() { + clear_vcc(); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + vsubc $v29, $v00, $v00 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - clear_vcc() - fails with left side", "[vcc]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a = clear_vcc(); + })", + "must not have a left side"); +} + +TEST_CASE("Builtins - clear_vcc() - fails with arguments", "[vcc]") { + REQUIRE_THROWS_MSG(R"(function test() { + clear_vcc(42); + })", + "requires no arguments"); +} + +// ========================================================================== +// get_acc +// ========================================================================== + +TEST_CASE("Builtins - get_acc() - basic", "[get_acc]") { + auto r = TRANSPILE(R"(function test() { + vec32<$v01> a = get_acc(); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + vsar $v01, COP2_ACC_HI + vsar $v02, COP2_ACC_MD + jr $ra + nop)"); +} + +TEST_CASE("Builtins - get_acc() - fails with no left side", "[get_acc]") { + REQUIRE_THROWS_MSG(R"(function test() { + get_acc(); + })", + "must have a left side"); +} + +TEST_CASE("Builtins - get_acc() - fails with arguments", "[get_acc]") { + REQUIRE_THROWS_MSG(R"(function test() { + vec32<$v01> a = get_acc(42); + })", + "requires no arguments"); +} + +TEST_CASE("Builtins - get_acc() - fails with scalar left side", + "[get_acc]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a = get_acc(); + })", + "vector variable"); +} + +TEST_CASE("Builtins - get_acc() - fails with vec16 left side", "[get_acc]") { + REQUIRE_THROWS_MSG(R"(function test() { + vec16<$v01> a = get_acc(); + })", + "vec32"); +} + +// ========================================================================== +// get_acc_high / get_acc_mid / get_acc_low +// ========================================================================== + +#define ACC_SINGLE_TESTS(name, cop2Reg) \ + TEST_CASE("Builtins - " name "() - basic", "[acc_single]") { \ + auto r = TRANSPILE("function test() {\n vec16<$v03> a = " name \ + "();\n }"); \ + REQUIRE_NO_WARN(r); \ + REQUIRE_ASM(r, "test:\n vsar $v03, " cop2Reg "\n jr $ra\n nop"); \ + } \ + TEST_CASE("Builtins - " name "() - fails with scalar left side", \ + "[acc_single]") { \ + REQUIRE_THROWS_MSG( \ + "function test() {\n u32<$t0> a = " name "();\n }", \ + "vector variable"); \ + } \ + TEST_CASE("Builtins - " name "() - fails with vec32 left side", \ + "[acc_single]") { \ + REQUIRE_THROWS_MSG( \ + "function test() {\n vec32<$v01> a = " name \ + "();\n }", \ + "vec16"); \ + } \ + TEST_CASE("Builtins - " name "() - fails with arguments", \ + "[acc_single]") { \ + REQUIRE_THROWS_MSG( \ + "function test() {\n vec16<$v01> a = " name \ + "(1);\n }", \ + "requires no arguments"); \ + } + +ACC_SINGLE_TESTS("get_acc_high", "COP2_ACC_HI") +ACC_SINGLE_TESTS("get_acc_mid", "COP2_ACC_MD") +ACC_SINGLE_TESTS("get_acc_low", "COP2_ACC_LO") + +// ========================================================================== +// get_vcc +// ========================================================================== + +TEST_CASE("Builtins - get_vcc() - basic", "[vcc]") { + auto r = TRANSPILE(R"(function test() { + u32<$t0> a = get_vcc(); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + cfc2 $t0, $vcc + jr $ra + nop)"); +} + +TEST_CASE("Builtins - get_vcc() - fails with no left side", "[vcc]") { + REQUIRE_THROWS_MSG(R"(function test() { + get_vcc(); + })", + "must have a left side"); +} + +TEST_CASE("Builtins - get_vcc() - fails with arguments", "[vcc]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a = get_vcc(1); + })", + "requires no arguments"); +} + +TEST_CASE("Builtins - get_vcc() - fails with vector left side", "[vcc]") { + REQUIRE_THROWS_MSG(R"(function test() { + vec16<$v01> a = get_vcc(); + })", + "scalar variable"); +} + +// ========================================================================== +// set_vcc +// ========================================================================== + +TEST_CASE("Builtins - set_vcc() - basic - scalar", "[vcc]") { + auto r = TRANSPILE(R"(function test() { + u32<$t0> a; + set_vcc(a); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + ctc2 $t0, $vcc + jr $ra + nop)"); +} + +TEST_CASE("Builtins - set_vcc() - basic - literal", "[vcc]") { + auto r = TRANSPILE(R"(function test() { + set_vcc(1); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + addiu $at, $zero, 1 + ctc2 $at, $vcc + jr $ra + nop)"); +} + +TEST_CASE("Builtins - set_vcc() - fails with left side", "[vcc]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a = set_vcc(1); + })", + "must not have a left side"); +} + +TEST_CASE("Builtins - set_vcc() - fails with no argument", "[vcc]") { + REQUIRE_THROWS_MSG(R"(function test() { + set_vcc(); + })", + "requires 1 scalar"); +} + +TEST_CASE("Builtins - set_vcc() - fails with vector argument", "[vcc]") { + REQUIRE_THROWS_MSG(R"(function test() { + vec16<$v01> a; + set_vcc(a); + })", + "scalar argument"); +} + +// ========================================================================== +// Inline ASM — asm() and asm_op() +// ========================================================================== + +TEST_CASE("Builtins - asm() - basic - string only", "[asm]") { + auto r = TRANSPILE(R"(function test() { + asm("nop"); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + nop # inline-ASM + jr $ra + nop)"); +} + +TEST_CASE("Builtins - asm() - with substitution - number", "[asm]") { + auto r = TRANSPILE(R"(function test() { + asm("addiu $t0, $zero, %0", 42); + })"); + REQUIRE_NO_WARN(r); + REQUIRE(r.asm_.find("addiu $t0, $zero, 42") != std::string::npos); + REQUIRE(r.asm_.find("# inline-ASM") != std::string::npos); +} + +TEST_CASE("Builtins - asm() - with substitution - register", "[asm]") { + auto r = TRANSPILE(R"(function test() { + u32<$t1> a; + asm("addu $t0, $zero, %0", a); + })"); + REQUIRE_NO_WARN(r); + REQUIRE(r.asm_.find("addu $t0, $zero, $t1") != std::string::npos); + REQUIRE(r.asm_.find("# inline-ASM") != std::string::npos); +} + +TEST_CASE("Builtins - asm() - fails with left side", "[asm]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a = asm("nop"); + })", + "cannot have a left side"); +} + +TEST_CASE("Builtins - asm() - fails with first arg not string", "[asm]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a; + asm(42); + })", + "first argument to be a string"); +} + +TEST_CASE("Builtins - asm_op() - basic", "[asm_op]") { + auto r = TRANSPILE(R"(function test() { + u32<$t0> a; + asm_op("mtc0", a); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + mtc0 $t0 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - asm_op() - fails with left side", "[asm_op]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a = asm_op("nop"); + })", + "cannot have a left side"); +} + +TEST_CASE("Builtins - asm_op() - fails with first arg not string", + "[asm_op]") { + REQUIRE_THROWS_MSG(R"(function test() { + asm_op(42); + })", + "opcode"); +} + +// ========================================================================== +// dma_await +// ========================================================================== + +TEST_CASE("Builtins - dma_await() - basic", "[dma]") { + auto r = TRANSPILE(R"(function test() { + dma_await(); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + jal DMAWaitIdle + nop + jr $ra + nop)"); +} + +TEST_CASE("Builtins - dma_await() - fails with left side", "[dma]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a = dma_await(); + })", + "cannot have a left side"); +} + +TEST_CASE("Builtins - dma_await() - fails with arguments", "[dma]") { + REQUIRE_THROWS_MSG(R"(function test() { + dma_await(1); + })", + "requires no arguments"); +} + +// ========================================================================== +// get_cmd_address +// ========================================================================== + +TEST_CASE("Builtins - get_cmd_address() - basic with offset", "[cmd]") { + auto r = TRANSPILE(R"(function test() { + u32<$t1> a = get_cmd_address(12); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + addiu $t1, $gp, %lo(RSPQ_DMEM_BUFFER) + 12 + jr $ra + nop)"); +} + +TEST_CASE("Builtins - get_cmd_address() - fails with no left side", + "[cmd]") { + REQUIRE_THROWS_MSG(R"(function test() { + get_cmd_address(); + })", + "must have a left side"); +} + +TEST_CASE("Builtins - get_cmd_address() - fails with too many arguments", + "[cmd]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a = get_cmd_address(1, 2); + })", + "zero or one argument"); +} + +TEST_CASE("Builtins - get_cmd_address() - fails with non-number argument", + "[cmd]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a; + u32<$t1> b = get_cmd_address(a); + })", + "number"); +} + +TEST_CASE("Builtins - get_cmd_address() - fails with vector left side", + "[cmd]") { + REQUIRE_THROWS_MSG(R"(function test() { + vec16<$v01> a = get_cmd_address(); + })", + "scalar variable"); +} + +// ========================================================================== +// load_arg +// ========================================================================== + +TEST_CASE("Builtins - load_arg() - basic with offset", "[cmd]") { + auto r = TRANSPILE(R"(function test() { + u32<$t1> a = load_arg(8); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + lw $t1, %lo(RSPQ_DMEM_BUFFER + 8)($gp) + jr $ra + nop)"); +} + +TEST_CASE("Builtins - load_arg() - fails with no left side", "[cmd]") { + REQUIRE_THROWS_MSG(R"(function test() { + load_arg(); + })", + "must have a left side"); +} + +TEST_CASE("Builtins - load_arg() - fails with too many arguments", "[cmd]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a = load_arg(1, 2); + })", + "zero or one argument"); +} + +TEST_CASE("Builtins - load_arg() - fails with non-number argument", "[cmd]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a; + u32<$t1> b = load_arg(a); + })", + "number"); +} + +TEST_CASE("Builtins - load_arg() - fails with vector left side", "[cmd]") { + REQUIRE_THROWS_MSG(R"(function test() { + vec16<$v01> a = load_arg(); + })", + "scalar variable"); +} + +// ========================================================================== +// max / min +// ========================================================================== + +TEST_CASE("Builtins - max() - basic", "[maxmin]") { + auto r = TRANSPILE(R"(function test() { + vec16<$v01> a, b; + vec16<$v03> c = max(a, b); + })"); + REQUIRE_NO_WARN(r); + REQUIRE(r.asm_.find("vge") != std::string::npos); +} + +TEST_CASE("Builtins - max() - fails with wrong arg count", "[maxmin]") { + REQUIRE_THROWS_MSG(R"(function test() { + vec16<$v01> a; + vec16<$v03> c = max(a); + })", + "exactly two arguments"); +} + +TEST_CASE("Builtins - max() - fails with type mismatch", "[maxmin]") { + REQUIRE_THROWS_MSG(R"(function test() { + vec16<$v01> a; + vec32<$v03> b; + vec16<$v05> c = max(a, b); + })", + "can only use vec16"); +} + +TEST_CASE("Builtins - min() - basic", "[maxmin]") { + auto r = TRANSPILE(R"(function test() { + vec16<$v01> a, b; + vec16<$v03> c = min(a, b); + })"); + REQUIRE_NO_WARN(r); + REQUIRE(r.asm_.find("vlt") != std::string::npos); +} + +TEST_CASE("Builtins - min() - fails with wrong arg count", "[maxmin]") { + REQUIRE_THROWS_MSG(R"(function test() { + vec16<$v01> a; + vec16<$v03> c = min(a); + })", + "exactly two arguments"); +} + +// ========================================================================== +// assert +// ========================================================================== + +TEST_CASE("Builtins - assert() - basic", "[assert]") { + auto r = TRANSPILE(R"(function test() { + assert(42); + })"); + REQUIRE_NO_WARN(r); + REQUIRE_ASM(r, R"(test: + lui $at, 42 + j assertion_failed + nop + jr $ra + nop)"); +} + +TEST_CASE("Builtins - assert() - fails with left side", "[assert]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a = assert(1); + })", + "cannot have a left side"); +} + +TEST_CASE("Builtins - assert() - fails with wrong arg count", "[assert]") { + REQUIRE_THROWS_MSG(R"(function test() { + assert(1, 2); + })", + "exactly one argument"); +} + +TEST_CASE("Builtins - assert() - fails with non-number argument", "[assert]") { + REQUIRE_THROWS_MSG(R"(function test() { + u32<$t0> a; + assert(a); + })", + "number"); +} \ No newline at end of file diff --git a/cpp/tests/test_builtinsDebug.cpp b/cpp/tests/test_builtinsDebug.cpp new file mode 100644 index 0000000..87d2da5 --- /dev/null +++ b/cpp/tests/test_builtinsDebug.cpp @@ -0,0 +1,280 @@ +#include +#include +#include "pipeline.h" + +#include + +// --- set_rsp_status --- + +TEST_CASE("Builtins - Debug - set_rsp_status() - scalar", "[builtinsDebug]") { + auto result = rspl::transpileSource( + R"(function test() { + u32<$t0> a; + set_rsp_status(a); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + mtc0 $t0, COP0_SP_STATUS + jr $ra + nop)"); +} + +TEST_CASE("Builtins - Debug - set_rsp_status() - scalar literal", + "[builtinsDebug]") { + auto result = rspl::transpileSource( + R"(function test() { + set_rsp_status(42); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $at, $zero, 42 + mtc0 $at, COP0_SP_STATUS + jr $ra + nop)"); +} + +TEST_CASE("Builtins - Debug - set_rsp_status() - fails with no argument", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + set_rsp_status(); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("requires 1 scalar")); +} + +TEST_CASE("Builtins - Debug - set_rsp_status() - fails with left side", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + u32<$t0> a = set_rsp_status(); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("must not have a left side")); +} + +TEST_CASE("Builtins - Debug - set_rsp_status() - fails with vector", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + vec16<$v01> a; + set_rsp_status(a); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("scalar argument")); +} + +// --- print --- + +TEST_CASE("Builtins - Debug - print() - scalar", "[builtinsDebug]") { + auto result = rspl::transpileSource( + R"(function test() { + u32<$t0> a; + u32<$t1> b; + print(a, b); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + .set macro # print + emux_dump_gpr $t0, $t1 + .set noat # print + .set nomacro # print + jr $ra + nop)"); +} + +TEST_CASE("Builtins - Debug - print() - vector", "[builtinsDebug]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a; + vec16<$v03> b; + print(a, b); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + .set macro # print + emux_dump_vpr $v01, $v03 + .set noat # print + .set nomacro # print + jr $ra + nop)"); +} + +TEST_CASE("Builtins - Debug - print() - string", "[builtinsDebug]") { + auto result = rspl::transpileSource( + R"(function test() { + print("hello", "world"); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + .set macro # print + emux_log_string "hello", "world" + .set noat # print + .set nomacro # print + jr $ra + nop)"); +} + +TEST_CASE("Builtins - Debug - print() - fails with no arguments", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + print(); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("requires at least one argument")); +} + +TEST_CASE("Builtins - Debug - print() - fails with left side", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + u32<$t0> a = print(); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("cannot have a left side")); +} + +TEST_CASE("Builtins - Debug - print() - fails with mixed types", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + u32<$t0> a; + print(a, "hello"); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("same type")); +} + +TEST_CASE("Builtins - Debug - print() - fails with number literal", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + print(42); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("variables or strings")); +} + +TEST_CASE("Builtins - Debug - print() - fails with mixed scalar/vector", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + u32<$t0> a; + vec16<$v01> b; + print(a, b); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("mixed scalar/vector")); +} + +// --- printf --- + +TEST_CASE("Builtins - Debug - printf() - basic scalar", "[builtinsDebug]") { + auto result = rspl::transpileSource( + R"(function test() { + u32<$t0> a; + printf("hello %d", a); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + .set macro # print + emux_printf "hello %dt0" + .set noat # print + .set nomacro # print + jr $ra + nop)"); +} + +TEST_CASE("Builtins - Debug - printf() - vec32 with swizzle", + "[builtinsDebug]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> a; + printf("result %f", a.x); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_.find("emux_printf") != std::string::npos); + REQUIRE(result.asm_.find("result %f") != std::string::npos); +} + +TEST_CASE("Builtins - Debug - printf() - vec16", "[builtinsDebug]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v03> a; + printf("val %d", a); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_.find("emux_printf") != std::string::npos); + REQUIRE(result.asm_.find("val %d") != std::string::npos); +} + +TEST_CASE("Builtins - Debug - printf() - string only", "[builtinsDebug]") { + auto result = rspl::transpileSource( + R"(function test() { + printf("hello world"); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_.find("emux_printf") != std::string::npos); + REQUIRE(result.asm_.find("hello world") != std::string::npos); +} + +TEST_CASE("Builtins - Debug - printf() - fails with no arguments", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + printf(); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("requires at least one argument")); +} + +TEST_CASE("Builtins - Debug - printf() - fails with left side", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + u32<$t0> a = printf("hello"); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("cannot have a left side")); +} + +TEST_CASE("Builtins - Debug - printf() - fails with non-string first arg", + "[builtinsDebug]") { + REQUIRE_THROWS_WITH( + rspl::transpileSource( + R"(function test() { + u32<$t0> a; + printf(a); + })", + {.rspqWrapper = false}), + Catch::Matchers::ContainsSubstring("first argument to be a string")); +} \ No newline at end of file diff --git a/cpp/tests/test_compare.cpp b/cpp/tests/test_compare.cpp new file mode 100644 index 0000000..6a096b8 --- /dev/null +++ b/cpp/tests/test_compare.cpp @@ -0,0 +1,189 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Comparison - Vector (vec16 vs vec16)", "[compare]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a, b; + res = a < b; + res = a >= b; + res = a == b; + res = a != b; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vlt $v01, $v02, $v03 + vge $v01, $v02, $v03 + veq $v01, $v02, $v03 + vne $v01, $v02, $v03 + jr $ra + nop)"); +} + +TEST_CASE("Comparison - Vector (vec16 vs const)", "[compare]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a, b; + res = a < 0; + res = a >= 2; + res = a == 32; + res = a != 256; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vlt $v01, $v02, $v00.e0 + vge $v01, $v02, $v30.e6 + veq $v01, $v02, $v30.e2 + vne $v01, $v02, $v31.e7 + jr $ra + nop)"); +} + +TEST_CASE("Comparison - Vector-Select (vec16)", "[compare]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a, b; + res = select(a, b); + res = select(a, 32); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmrg $v01, $v02, $v03 + vmrg $v01, $v02, $v30.e2 + jr $ra + nop)"); +} + +TEST_CASE("Comparison - Vector-Select (vec32)", "[compare]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a, b; + A: + res = select(a, b); + B: + res = select(a, b.y); + C: + res = select(a, 32); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + A: + vmrg $v01, $v03, $v05 + vmrg $v02, $v04, $v06 + B: + vmrg $v01, $v03, $v05.e1 + vmrg $v02, $v04, $v06.e1 + C: + vmrg $v01, $v03, $v30.e2 + vmrg $v02, $v04, $v00.e2 + jr $ra + nop)"); +} + +TEST_CASE("Comparison - Vector-Select (vec32 cast)", "[compare]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a, b; + + res:sint = select(a, b:sfract); + res:sfract = select(a, 32); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmrg $v01, $v03, $v06 + vmrg $v02, $v04, $v00.e2 + jr $ra + nop)"); +} + +TEST_CASE("Comparison - Vector-Ternary (vec16)", "[compare]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a, b; + vec16<$v10> x, y; + + A: + res = x != y ? a : b; + B: + res = x != 4 ? a : 32; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + A: + vne $v29, $v10, $v11 + vmrg $v01, $v02, $v03 + B: + vne $v29, $v10, $v30.e5 + vmrg $v01, $v02, $v30.e2 + jr $ra + nop)"); +} + +TEST_CASE("Comparison - Vector-Ternary (vec32)", "[compare]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a, b; + vec16<$v10> x, y; + + A: + res = x != y ? a : b; + B: + res = x != 4 ? a : 32; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + A: + vne $v29, $v10, $v11 + vmrg $v01, $v03, $v05 + vmrg $v02, $v04, $v06 + B: + vne $v29, $v10, $v30.e5 + vmrg $v01, $v03, $v30.e2 + vmrg $v02, $v04, $v00.e2 + jr $ra + nop)"); +} + +TEST_CASE("Comparison - Vector-Ternary (swizzle)", "[compare]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a, b; + A: + res = a == b ? a : b.y; + B: + res = a >= b.z ? a : b.y; + C: + res = a == b.z ? a : b; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + A: + veq $v29, $v02, $v03 + vmrg $v01, $v02, $v03.e1 + B: + vge $v29, $v02, $v03.e2 + vmrg $v01, $v02, $v03.e1 + C: + veq $v29, $v02, $v03.e2 + vmrg $v01, $v02, $v03 + jr $ra + nop)"); +} diff --git a/cpp/tests/test_const.cpp b/cpp/tests/test_const.cpp new file mode 100644 index 0000000..0415abe --- /dev/null +++ b/cpp/tests/test_const.cpp @@ -0,0 +1,67 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Const - Declaration", "[const]") { + auto result = rspl::transpileSource( + R"(function test() +{ + const u32<$t0> a = 1234; + const u32<$t1> b = a + a; + const vec16<$v01> c = 0; +})", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $t0, $zero, 1234 + addu $t1, $t0, $t0 + vxor $v01, $v00, $v00.e0 + jr $ra + nop)"); +} + +TEST_CASE("Const - Invalid (scalar reassignment)", "[const]") { + REQUIRE_THROWS_AS( + rspl::transpileSource( + R"(function test() { + const u32<$t0> a = 1234; + a += 1; +})", + {.rspqWrapper = false}), + std::runtime_error); + try { + rspl::transpileSource( + R"(function test() { + const u32<$t0> a = 1234; + a += 1; +})", + {.rspqWrapper = false}); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find( + "Cannot assign to constant variable") != std::string::npos); + } +} + +TEST_CASE("Const - Invalid (vector reassignment)", "[const]") { + REQUIRE_THROWS_AS( + rspl::transpileSource( + R"(function test() { + const vec16<$v01> a = 0; + a += a; +})", + {.rspqWrapper = false}), + std::runtime_error); + try { + rspl::transpileSource( + R"(function test() { + const vec16<$v01> a = 0; + a += a; +})", + {.rspqWrapper = false}); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find( + "Cannot assign to constant variable") != std::string::npos); + } +} diff --git a/cpp/tests/test_control.cpp b/cpp/tests/test_control.cpp new file mode 100644 index 0000000..5e2998e --- /dev/null +++ b/cpp/tests/test_control.cpp @@ -0,0 +1,18 @@ +#include +#include "pipeline.h" + +TEST_CASE("Control - Exit", "[control]") { + auto result = rspl::transpileSource( + R"(function test() +{ + exit; +})", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + j RSPQ_Loop + nop + jr $ra + nop)"); +} diff --git a/cpp/tests/test_debugInfo.cpp b/cpp/tests/test_debugInfo.cpp new file mode 100644 index 0000000..36b9b09 --- /dev/null +++ b/cpp/tests/test_debugInfo.cpp @@ -0,0 +1,97 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Debug Info - Basic operations", "[debugInfo]") { + auto result = rspl::transpileSource( + R"(function test() { + u32 a = 0; + a += 42; + })", + {.rspqWrapper = false, .debugInfo = true}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero ## L:2 | ^ | u32 a = 0; + addiu $t0, $t0, 42 ## L:3 | 2 | a += 42; + jr $ra ## L:4 | 3 | } + nop ## L:4 | *5 | })"); +} + +TEST_CASE("Debug Info - @Tag annotation prefix on instruction", + "[debugInfo]") { + auto result = rspl::transpileSource( + R"(function test() { + u32 a; + @Tag("Foo") a = 1; + })", + {.rspqWrapper = false, .debugInfo = true}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + TAG_Foo: addiu $t0, $zero, 1 ## L:3 | ^ | @Tag("Foo") a = 1; + jr $ra ## L:4 | 2 | } + nop ## L:4 | *4 | })"); +} + +TEST_CASE("Debug Info - @Tag annotation on used label", + "[debugInfo]") { + auto result = rspl::transpileSource( + R"(function test() { + @Tag("Start") LOOP: + u32 a = 1; + goto LOOP; + })", + {.rspqWrapper = false, .debugInfo = true}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + TAG_Start: LOOP: + addiu $t0, $zero, 1 ## L:3 | 1 | u32 a = 1; + j LOOP ## L:4 | 2 | goto LOOP; + nop ## L:4 | *4 | goto LOOP;)"); +} + +TEST_CASE("Debug Info - Transpose builtin shares line and barrier info", + "[debugInfo]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v08> v0; + u16 buff; + v0 = transpose(v0, buff, 4, 4); + })", + {.rspqWrapper = false, .debugInfo = true}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + stv $v08, 2, 16, $t0 ## L:4 | ^ | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + stv $v08, 4, 32, $t0 ## L:4 | 2 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + stv $v08, 6, 48, $t0 ## L:4 | 3 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + stv $v08, 10, 80, $t0 ## L:4 | 4 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + stv $v08, 12, 96, $t0 ## L:4 | 5 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + stv $v08, 14, 112, $t0 ## L:4 | 6 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + ltv $v08, 14, 16, $t0 ## L:4 | 7 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + ltv $v08, 12, 32, $t0 ## L:4 | 8 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + ltv $v08, 10, 48, $t0 ## L:4 | 9 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + ltv $v08, 6, 80, $t0 ## L:4 | 10 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + ltv $v08, 4, 96, $t0 ## L:4 | 11 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + ltv $v08, 2, 112, $t0 ## L:4 | 12 | v0 = transpose(v0, buff, 4, 4); ## Barrier: 0x1 + jr $ra ## L:5 | 13 | } + nop ## L:5 | *15 | })"); +} + +TEST_CASE("Debug Info - Disabled produces no padding or comments", + "[debugInfo]") { + auto result = rspl::transpileSource( + R"(function test() { + u32 a = 0; + })", + {.rspqWrapper = false, .debugInfo = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero + jr $ra + nop)"); +} diff --git a/cpp/tests/test_defineAsm.cpp b/cpp/tests/test_defineAsm.cpp new file mode 100644 index 0000000..744f52a --- /dev/null +++ b/cpp/tests/test_defineAsm.cpp @@ -0,0 +1,32 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Define (ASM) - Define in ASM", "[defineAsm]") { + auto result = rspl::transpileSource( + R"( + include "rsp_queue.inc" + include "rdpq_macros.h" + + #define SOME_DEF_A 1 + #define SOME_DEF_B 2 + + state{} + + #define SOME_DEF_C 3 + + command<0> test(u32 a) + { + } + + #define SOME_DEF_D 4 + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_.find("#define SOME_DEF_A 1") != std::string::npos); + REQUIRE(result.asm_.find("#define SOME_DEF_B 2") != std::string::npos); + REQUIRE(result.asm_.find("#define SOME_DEF_C 3") != std::string::npos); + REQUIRE(result.asm_.find("#define SOME_DEF_D 4") != std::string::npos); +} diff --git a/cpp/tests/test_dma.cpp b/cpp/tests/test_dma.cpp new file mode 100644 index 0000000..6f0e67f --- /dev/null +++ b/cpp/tests/test_dma.cpp @@ -0,0 +1,94 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("DMA - dma_in (sync) with explicit size", "[dma]") { + auto result = rspl::transpileSource( + R"(function test() { + u32<$s4> dest = 0x1000; + u32<$s0> rdram = 0x2000; + u32<$t0> size = 32; + dma_in(dest, rdram, size); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $s4, $zero, 4096 + addiu $s0, $zero, 8192 + addiu $t0, $zero, 32 + addiu $t0, $t0, -1 + addiu $t2, $zero, 12 + jal DMAExec + nop + jr $ra + nop)"); +} + +TEST_CASE("DMA - dma_out_async with explicit size", "[dma]") { + auto result = rspl::transpileSource( + R"(function test() { + u32<$s4> dest = 0x1000; + u32<$s0> rdram = 0x2000; + u32<$t0> size = 64; + dma_out_async(dest, rdram, size); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $s4, $zero, 4096 + addiu $s0, $zero, 8192 + addiu $t0, $zero, 64 + addiu $t0, $t0, -1 + addiu $t2, $zero, -32768 + jal DMAExec + nop + jr $ra + nop)"); +} + +TEST_CASE("DMA - dma_in_async with memory dest (2-arg)", "[dma]") { + auto result = rspl::transpileSource( + R"(state { u8 BUFF[64]; } +function test() { + u32<$s0> rdram = 0x2000; + dma_in_async(BUFF, rdram); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $s0, $zero, 8192 + ori $t0, $zero, 63 + ori $s4, $zero, %lo(BUFF) + or $t2, $zero, $zero + jal DMAExec + nop + jr $ra + nop)"); +} + +TEST_CASE("DMA - dma_in_async with register dest (3-arg)", "[dma]") { + auto result = rspl::transpileSource( + R"(function test() { + u32<$s4> dest = 0x1000; + u32<$s0> rdram = 0x2000; + u32<$t0> size = 64; + dma_in_async(dest, rdram, size); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $s4, $zero, 4096 + addiu $s0, $zero, 8192 + addiu $t0, $zero, 64 + addiu $t0, $t0, -1 + or $t2, $zero, $zero + jal DMAExec + nop + jr $ra + nop)"); +} diff --git a/cpp/tests/test_evalCost.cpp b/cpp/tests/test_evalCost.cpp new file mode 100644 index 0000000..a2ff542 --- /dev/null +++ b/cpp/tests/test_evalCost.cpp @@ -0,0 +1,281 @@ +#include +#include "asm.h" +#include "optimizer/asm_scan_deps.h" +#include "optimizer/eval_cost.h" + +#include +#include +#include + +using namespace rspl; + +// Parse text lines like "or $t0, $zero, $zero" into AsmInst vectors +static std::vector textToAsmLines(const std::string &text) { + std::vector lines; + std::istringstream ss(text); + std::string line; + while (std::getline(ss, line)) { + // trim + size_t start = line.find_first_not_of(" \t"); + if (start == std::string::npos) continue; + size_t end = line.find_last_not_of(" \t"); + line = line.substr(start, end - start + 1); + if (line.empty()) continue; + + std::istringstream ls(line); + std::string op; + ls >> op; + std::vector args; + std::string arg; + while (ls >> arg) { + if (arg.back() == ',') arg.pop_back(); + args.push_back(arg); + } + if (op == "nop") + lines.push_back(asmNOP()); + else + lines.push_back(asmOp(op, args)); + } + return lines; +} + +static std::vector linesToCycles(std::vector &lines) { + AsmFunc func; + func.asm_ = std::move(lines); + asmInitDeps(func); + evalFunctionCost(func); + std::vector cycles; + for (const auto &inst : func.asm_) + cycles.push_back(inst.debug.cycle); + return cycles; +} + +#define CHECK_CYCLES(name, text, ...) \ + TEST_CASE("Eval - Cost - " name, "[evalCost]") { \ + auto lines = textToAsmLines(text); \ + std::vector expected = __VA_ARGS__; \ + auto cycles = linesToCycles(lines); \ + REQUIRE(cycles == expected); \ + } + +CHECK_CYCLES("SU only - no dep", + "or $t0, $zero, $zero\n" + "addiu $t0, $t0, 1\n" + "addiu $t0, $t0, 1\n" + "addiu $t0, $t0, 1\n" + "jr $ra\n" + "nop\n", + {1, 2, 3, 4, 5, 7}) + +CHECK_CYCLES("SU only - deps", + "or $t0, $zero, $zero\n" + "or $t1, $zero, $t0\n" + "addu $t2, $t0, $t1\n", + {1, 2, 3}) + +CHECK_CYCLES("VU only - no dep", + "vxor $v01, $v00, $v00.e0\n" + "vxor $v01, $v00, $v30.e7\n" + "vxor $v01, $v00, $v30.e6\n" + "vxor $v01, $v00, $v30.e5\n", + {1, 2, 3, 4}) + +CHECK_CYCLES("VU only - ACC / 32-bit mul", + "vmudl $v29, $v05, $v09.v\n" + "vmadm $v29, $v04, $v09.v\n" + "vmadn $v11, $v05, $v08.v\n" + "vmadh $v10, $v04, $v08.v\n", + {1, 2, 3, 4}) + +CHECK_CYCLES("VU only - DIV / invert_half", + "vrcph $v04.e0, $v04.e0\n" + "vrcpl $v05.e0, $v05.e0\n" + "vrcph $v04.e0, $v08.e0\n" + "vrcpl $v09.e0, $v09.e0\n" + "vrcph $v08.e0, $v00.e0\n", + {1, 2, 3, 4, 5}) + +CHECK_CYCLES("VU only - ternary", + "vaddc $v06, $v06, $v06.v\n" + "vadd $v11, $v05, $v05.v\n" + "vne $v29, $v18, $v00.e0\n" + "vmrg $v13, $v05, $v07\n" + "vmrg $v14, $v06, $v08\n", + {1, 2, 3, 4, 5}) + +CHECK_CYCLES("VU only - deps", + "vxor $v01, $v00, $v00.e0\n" + "vaddc $v04, $v01, $v01.v\n" + "vxor $v05, $v00, $v00.e0\n" + "vaddc $v04, $v01, $v01.v\n", + {1, 5, 6, 7}) + +CHECK_CYCLES("SU/VU mix - no dep", + "vxor $v01, $v00, $v00.e0\n" + "vxor $v01, $v00, $v30.e7\n" + "addiu $t0, $zero, 4\n" + "vxor $v01, $v00, $v30.e6\n" + "addiu $t1, $zero, 4\n" + "addiu $t2, $zero, 4\n", + {1, 2, 2, 3, 3, 4}) + +CHECK_CYCLES("VU - same src/dst dep", + "vxor $v01, $v00, $v30.e7\n" + "vxor $v01, $v01, $v01\n" + "vor $v05, $v00, $v01\n", + {1, 5, 9}) + +CHECK_CYCLES("VU/SU - no dual issue (lqv)", + "vand $v04, $v30, $v31\n" + "lqv $v04, 0, 0, $s4\n", + {1, 2}) + +CHECK_CYCLES("VU/SU - no dual issue (mfc2)", + "vand $v04, $v30, $v31\n" + "mfc2 $t0, $v04.e4\n", + {1, 5}) + +CHECK_CYCLES("SU/VU mix - delay slot", + "or $t0, $zero, $zero\n" + "bne $t0, $zero, END\n" + "nop\n" + "vxor $v01, $v00, $v30.e7\n" + "vaddc $v02, $v03, $v30.e7\n", + {1, 2, 4, 5, 6}) + +CHECK_CYCLES("SU/VU mix - MTC2", + "vxor $v04, $v00, $v00.e0\n" + "vxor $v05, $v00, $v00\n" + "mtc2 $t0, $v05.e0\n" + "srl $at, $t0, 16\n" + "srl $at, $t0, 16\n" + "mtc2 $at, $v04.e0\n" + "vmov $v04.e2, $v04.e0\n" + "vmov $v05.e2, $v05.e0\n" + "vmov $v04.e3, $v04.e1\n" + "vmov $v05.e3, $v05.e1\n", + {1, 2, 3, 4, 5, 6, 10, 11, 14, 15}) + +CHECK_CYCLES("SU/VU in dual - MFC2", + "vmadn $v06, $v06, $v05.h3\n" + "vmadh $v05, $v05, $v05.h3\n" + "nop\n" + "mfc2 $fp, $v02.e2\n" + "vmudl $v29, $v06, $v14.v\n" + "vor $v00, $v00, $v00\n" + "nop\n" + "sra $fp, $fp, 7\n", + {1, 2, 2, 5, 5, 6, 6, 8}) + +CHECK_CYCLES("SU memory - load (dep)", + "lhu $s3, 24($s4)\n" + "lhu $s2, 24($s4)\n" + "srl $s2, $s2, 2\n", + {1, 2, 5}) + +CHECK_CYCLES("SU memory - load/store (dep)", + "lw $t1, 0($t0)\n" + "addiu $t2, $zero, 3\n" + "sw $t1, ($t0)\n" + "addiu $t2, $zero, 4\n", + {1, 2, 4, 5}) + +CHECK_CYCLES("SU memory - load/store (no-dep)", + "lw $t4, 0($t0)\n" + "addiu $t2, $zero, 3\n" + "sw $t1, ($t0)\n" + "addiu $t2, $zero, 4\n", + {1, 2, 4, 5}) + +CHECK_CYCLES("SU memory - load/store (multiple)", + "lhu $s3, 24($s4)\n" + "lhu $s2, 24($s4)\n" + "lhu $s1, 24($s4)\n" + "sh $s3, 12($s6)\n" + "sh $s2, 12($s5)\n" + "sh $s1, 12($s5)\n", + {1, 2, 3, 6, 7, 8}) + +CHECK_CYCLES("SU memory - load/store (no-dep, dual)", + "lw $t4, 0($t0)\n" + "vxor $v11, $v11, $v11\n" + "addiu $t2, $zero, 3\n" + "sw $t1, ($t0)\n" + "addiu $t2, $zero, 4\n", + {1, 1, 2, 4, 5}) + +CHECK_CYCLES("CFC2 - stall", + "cfc2 $sp, $vcc\n" + "andi $sp, $sp, 1799\n" + "srl $t7, $sp, 5\n", + {1, 4, 5}) + +CHECK_CYCLES("VU + CFC2 - dual", + "vxor $v01, $v01, $v01\n" + "cfc2 $sp, $vcc\n" + "andi $sp, $sp, 1799\n" + "srl $t7, $sp, 5\n", + {1, 1, 4, 5}) + +CHECK_CYCLES("CFC2 + VU - dual", + "cfc2 $sp, $vcc\n" + "vxor $v01, $v01, $v01\n" + "andi $sp, $sp, 1799\n" + "srl $t7, $sp, 5\n", + {1, 1, 4, 5}) + +CHECK_CYCLES("VU + CFC2 - no-dual", + "vcl $v29, $v27, $v20\n" + "cfc2 $sp, $vcc\n" + "andi $sp, $sp, 1799\n" + "srl $t7, $sp, 5\n", + {1, 2, 5, 6}) + +CHECK_CYCLES("CFC2 + VU - dual 2", + "cfc2 $sp, $vcc\n" + "vcl $v29, $v27, $v20\n" + "andi $sp, $sp, 1799\n" + "srl $t7, $sp, 5\n", + {1, 1, 4, 5}) + +CHECK_CYCLES("Branch - NOP", + "or $s7, $zero, $zero\n" + "beq $s7, $zero, LABEL_0001\n" + "nop\n" + "vxor $v28, $v00, $v30.e7\n", + {1, 2, 4, 5}) + +CHECK_CYCLES("Branch - dual-issue", + "vxor $v00, $v00, $v00\n" + "bne $s7, $zero, LABEL_0001\n" + "nop\n" + "vxor $v28, $v00, $v30.e7\n" + "nop\n", + {1, 1, 3, 5, 5}) + +CHECK_CYCLES("Branch multiple - NOP", + "beq $zero, $zero, LABEL_A\n" + "nop\n" + "beq $zero, $zero, LABEL_B\n" + "nop\n", + {1, 3, 4, 6}) + +CHECK_CYCLES("Branch - filled (scalar)", + "or $s7, $zero, $zero\n" + "beq $s7, $zero, LABEL_0001\n" + "addiu $s6, $zero, 3\n" + "addiu $s6, $zero, 1\n", + {1, 2, 4, 5}) + +CHECK_CYCLES("Branch - filled + stall (scalar)", + "lw $a0, %lo(SCREEN_SIZE_VEC + 0)\n" + "beq $zero, $zero, LABEL_A\n" + "addiu $a0, $a0, 1\n", + {1, 2, 5}) + +CHECK_CYCLES("Branch - filled (vector)", + "or $s7, $zero, $zero\n" + "beq $s7, $zero, LABEL_0001\n" + "vxor $v28, $v00, $v30.e7\n" + "vxor $v28, $v00, $v30.e7\n", + {1, 2, 4, 5}) diff --git a/cpp/tests/test_evalCostExample.cpp b/cpp/tests/test_evalCostExample.cpp new file mode 100644 index 0000000..1de47da --- /dev/null +++ b/cpp/tests/test_evalCostExample.cpp @@ -0,0 +1,246 @@ +#include +#include "asm.h" +#include "optimizer/asm_scan_deps.h" +#include "optimizer/eval_cost.h" + +#include +#include +#include +#include +#include + +using namespace rspl; + +// Parse text with bracket annotations like "[0] nop", "[^] vadd..." into +// AsmInst vectors. Lines with "# unlikely" clear the likely-branch flags. +static std::vector textToAsmLines(const std::string &text) { + std::vector lines; + std::istringstream ss(text); + std::string line; + while (std::getline(ss, line)) { + // strip leading bracket annotation [*] + auto rb = line.find(']'); + if (rb == std::string::npos) continue; + line = line.substr(rb + 1); + + // trim + size_t start = line.find_first_not_of(" \t"); + if (start == std::string::npos) continue; + size_t end = line.find_last_not_of(" \t"); + line = line.substr(start, end - start + 1); + if (line.empty()) continue; + + // remove trailing comment + auto hashPos = line.find('#'); + bool unlikely = false; + if (hashPos != std::string::npos) { + if (line.find("unlikely", hashPos) != std::string::npos) + unlikely = true; + line = line.substr(0, hashPos); + // trim again + end = line.find_last_not_of(" \t"); + if (end == std::string::npos) continue; + line = line.substr(0, end + 1); + } + + std::istringstream ls(line); + std::string op; + ls >> op; + std::vector args; + std::string arg; + while (ls >> arg) { + if (arg.back() == ',') arg.pop_back(); + args.push_back(arg); + } + AsmInst inst; + if (op == "nop") + inst = asmNOP(); + else + inst = asmOp(op, args); + + if (inst.opFlags & OpFlag::OP_FLAG_IS_BRANCH) { + if (unlikely) { + inst.opFlags &= + ~(OpFlag::OP_FLAG_LIKELY_BRANCH | OpFlag::OP_FLAG_IS_LIKELY); + } else { + inst.opFlags |= + (OpFlag::OP_FLAG_LIKELY_BRANCH | OpFlag::OP_FLAG_IS_LIKELY); + } + } + lines.push_back(std::move(inst)); + } + return lines; +} + +// Parse the bracket annotations into expected cycle numbers. +static std::vector textToAsmCycle(const std::string &text) { + std::vector annotations; + std::istringstream ss(text); + std::string line; + while (std::getline(ss, line)) { + auto lb = line.find('['); + auto rb = line.find(']'); + if (lb == std::string::npos || rb == std::string::npos) continue; + std::string a = line.substr(lb + 1, rb - lb - 1); + // trim + size_t s = a.find_first_not_of(" \t"); + if (s == std::string::npos) continue; + size_t e = a.find_last_not_of(" \t"); + a = a.substr(s, e - s + 1); + annotations.push_back(a); + } + + std::vector cycles; + int lastCycle = 0; + for (size_t i = 0; i < annotations.size(); ++i) { + int stars = static_cast( + std::count(annotations[i].begin(), annotations[i].end(), '*')); + if (!annotations[i].starts_with("^")) { + lastCycle = std::stoi(annotations[i]); + } else { + if (i > 0) cycles[i - 1] += stars; + } + lastCycle += stars; + cycles.push_back(lastCycle + 1); + } + return cycles; +} + +static std::vector linesToCycles(std::vector &lines) { + AsmFunc func; + func.asm_ = std::move(lines); + asmInitDeps(func); + evalFunctionCost(func); + std::vector cycles; + for (const auto &inst : func.asm_) + cycles.push_back(inst.debug.cycle); + return cycles; +} + +static const std::string T3D_CODE = R"( +[0] nop +[0] vmulf $v06, $v20, $v07.h0 +[1] ori $at, $zero, %lo(COLOR_AMBIENT) +[^] vmacf $v06, $v19, $v07.h1 +[2] ori $s3, $zero, %lo(LIGHT_DIR_COLOR) +[^] vmacf $v07, $v18, $v07.h2 +[3] vmudn $v06, $v28, $v08.h0 +[4] vmadh $v05, $v27, $v08.h0 +[5] vmadn $v06, $v26, $v08.h1 +[^] luv $v03, 0, 0, $at +[6] vmadh $v05, $v25, $v08.h1 +[7] vmadn $v06, $v24, $v08.h2 +[8] vmadh $v05, $v23, $v08.h2 +[9] vmadn $v06, $v22, $v08.h3 +[^] luv $v04, 0, 16, $s4 +[10] vmadh $v05, $v21, $v08.h3 +[^] lpv $v08, 0, 8, $s3 +[11] beq $s3, $s2, LABEL_0003 # unlikely +[12] nop +[13] luv $v01, 0, 0, $s3 +[^*] vmulf $v02, $v07, $v08.v +[15] addiu $s3, $s3, 16 +[16] lpv $v08, 0, 8, $s3 +[^**] vmulu $v29, $v01, $v02.h0 +[19] vmacu $v29, $v01, $v02.h1 +[20] vmacu $v29, $v01, $v02.h2 +[^] bne $s3, $s2, LABEL_0004 # unlikely +[21***] vadd $v03, $v03, $v29.v +[25] vmudl $v29, $v00, $v06.h3 +[26] vmadm $v29, $v15, $v06.h3 +[27] vmadn $v02, $v00, $v05.h3 +[^] lqv $v08, 0, 32, $s4 +[28] vmadh $v01, $v15, $v05.h3 +[29] vch $v29, $v05, $v05.h3 +[30] vcl $v29, $v06, $v06.h3 +[31] cfc2 $t6, $vcc +[32] addiu $s1, $s1, 72 +[ ^] vch $v29, $v05, $v01 +[33] vcl $v29, $v06, $v02 +[34] vmulf $v04, $v04, $v03.v +[ ^] cfc2 $t5, $vcc +[35] vmudl $v06, $v06, $v10.v +[ ^] andi $t8, $t6, 1799 +[36] vmadm $v05, $v05, $v10.v +[37] vmadn $v06, $v00, $v00 +[ ^] srl $t9, $t5, 4 +[38] andi $k0, $t5, 1799 +[39] srl $t4, $k0, 5 +[40] sdv $v05, 8, 16, $s5 +[41] sdv $v05, 0, 16, $s6 +[ ^] vrcph $v05.e3, $v05.e3 +[42] sdv $v06, 0, 24, $s6 +[43] sdv $v06, 8, 24, $s5 +[ ^] vrcpl $v06.e3, $v06.e3 +[44] andi $t9, $t9, 1799 +[45] or $k0, $k0, $t4 +[46] srl $t4, $t9, 5 +[ ^] vrcph $v05.e3, $v05.e7 +[47] vrcpl $v06.e7, $v06.e7 +[ ^] or $t9, $t9, $t4 +[48] srl $t4, $t8, 5 +[ ^] vrcph $v05.e7, $v00.e7 +[49] nor $t8, $t8, $t4 +[50] srl $t7, $t6, 4 +[ ^*] vaddc $v03, $v06, $v11.e1 +[ 52] vadd $v02, $v05, $v11.e0 +[ ^] ssv $v05, 6, 32, $s6 +[ 53] andi $t8, $t8, 255 +[ 54] suv $v04, 0, 8, $s6 +[ 55] ssv $v05, 14, 32, $s5 +[ ^] vmudn $v03, $v03, $v11.e3 +[ 56] ldv $v03, 0, 24, $s4 +[ ^] vmadh $v02, $v02, $v11.e3 +[ 57] ssv $v06, 14, 34, $s5 +[ 58] addiu $s4, $s4, 32 +[ 59] ssv $v06, 6, 34, $s6 +[ 60] andi $t7, $t7, 1799 +[ ^] vsub $v02, $v11, $v02.v +[ 61] sll $k0, $k0, 8 +[ ^] vmudl $v29, $v06, $v06.h3 +[ 62] srl $t4, $t7, 5 +[ ^] vmadm $v29, $v05, $v06.h3 +[ 63] vmadn $v06, $v06, $v05.h3 +[ ^] nor $t7, $t7, $t4 +[ 64] vmadh $v05, $v05, $v05.h3 +[ ^] mfc2 $sp, $v02.e6 +[ 65] mfc2 $fp, $v02.e2 +[^**] vmudl $v29, $v06, $v14.v +[ 68] vmadm $v29, $v05, $v14.v +[ ^] sra $sp, $sp, 7 +[ 69] sra $fp, $fp, 7 +[ ^*] vmadn $v06, $v06, $v13.v +[ 71] vmadh $v05, $v05, $v13.v +[ 72] vmadh $v05, $v12, $v30.e7 +[ ^] suv $v04, 4, 8, $s5 +[ 73] vor $v02, $v00, $v07 +[ ^] sb $fp, -69($s1) +[ 74] vand $v07, $v17, $v08.h3 +[ ^] or $k0, $k0, $t8 +[ 75] sb $sp, -33($s1) +[ 76] sdv $v05, 0, 0, $s6 +[ 77] sdv $v05, 8, 0, $s5 +[ 78] sh $k0, 6($s6) +[ ^] vmudn $v07, $v07, $v16.v +[ 79] sb $t9, 6($s5) +[ ^] vmov $v08.e3, $v30.e7 +[ 80] vmov $v08.e7, $v30.e7 +[ ^] jal $k1 +[81*] sb $t7, 7($s5) +[ 83] slv $v03, 4, 12, $s5 +[ 84] slv $v03, 0, 12, $s6 +[ 85] addiu $s6, $s6, 72 +)"; + +TEST_CASE("Eval - Cost (Examples) - T3D Vertex Loop - 0", "[evalCostExample]") { + auto lines = textToAsmLines(T3D_CODE); + auto cyclesExp = textToAsmCycle(T3D_CODE); + + auto cycles = linesToCycles(lines); + + REQUIRE(cycles.size() == cyclesExp.size()); + for (size_t line = 0; line < cycles.size(); ++line) { + INFO("Line " << line); + REQUIRE(cycles[line] == cyclesExp[line]); + } +} diff --git a/cpp/tests/test_examples.cpp b/cpp/tests/test_examples.cpp new file mode 100644 index 0000000..91a6c6f --- /dev/null +++ b/cpp/tests/test_examples.cpp @@ -0,0 +1,295 @@ +#include +#include "diff_util.h" +#include "pipeline.h" + +#include +#include +#include + +static std::string readFile(const std::string &path) { + std::ifstream f(path); + REQUIRE(f.is_open()); + std::ostringstream ss; + ss << f.rdbuf(); + return ss.str(); +} + +static std::string examplesPath(const std::string &rel) { + return "src/tests/examples/" + rel; +} + +// Transpile RSPL source, expect no errors and return the ASM. +static std::string transpile(const std::string &src, bool optimize, + bool debugInfo = false, + const std::string &sourceDir = ".") { + rspl::TranspileConfig cfg; + cfg.rspqWrapper = true; + cfg.optimize = optimize; + cfg.debugInfo = debugInfo; + cfg.sourceDir = sourceDir; + auto res = rspl::transpileSource(src, cfg); + REQUIRE(res.warn.empty()); + return res.asm_; +} + +static std::string transpileFile(const std::string &path, bool optimize, + bool debugInfo = false) { + // Set sourceDir to the directory of the file so includes resolve correctly + auto lastSlash = path.rfind('/'); + std::string dir = (lastSlash != std::string::npos) ? path.substr(0, lastSlash) : "."; + return transpile(readFile(path), optimize, debugInfo, dir); +} + +TEST_CASE("Examples - Squares 2D", "[examples]") { + auto code = readFile(examplesPath("squares2d.rspl")); + REQUIRE_NOTHROW(transpile(code, false)); +} + +TEST_CASE("Examples - 3D", "[examples]") { + auto path = examplesPath("3d.rspl"); + auto expected = readFile(examplesPath("3d.S")); + auto asm_ = transpileFile(path, true); + REQUIRE_ASM_EQ(expected, asm_); +} + +TEST_CASE("Examples - Tiny3D", "[examples]") { + auto path = examplesPath("t3d/rsp_tiny3d.rspl"); + auto expected = readFile(examplesPath("t3d/rsp_tiny3d.S")); + auto asm_ = transpileFile(path, true, true); + REQUIRE_ASM_EQ(expected, asm_); +} + +TEST_CASE("Examples - TinyPX", "[examples]") { + auto path = examplesPath("t3d/rsp_tinypx.rspl"); + auto expected = readFile(examplesPath("t3d/rsp_tinypx.S")); + auto asm_ = transpileFile(path, true, true); + REQUIRE_ASM_EQ(expected, asm_); +} + +TEST_CASE("Examples - HDR/Bloom", "[examples]") { + auto path = examplesPath("rsp_fx.rspl"); + REQUIRE_NOTHROW(transpileFile(path, true, true)); +} + +TEST_CASE("Examples - Mandelbrot", "[examples]") { + auto code = readFile(examplesPath("mandelbrot.rspl")); + REQUIRE_NOTHROW(transpile(code, false)); +} + +TEST_CASE("Examples - Matrix x Vector", "[examples]") { + auto src = R"( +include "rsp_queue.inc" +state { + vec32 VEC_SLOTS[20]; +} + +command<0> VecCmd_Transform(u32 vec_out, u32 mat_in) +{ + u32<$t0> trans_mtx = mat_in >> 16; + trans_mtx &= 0xFF0; + + u32<$t1> trans_vec = mat_in & 0xFF0; + u32<$t2> trans_out = vec_out & 0xFF0; + + trans_mtx += VEC_SLOTS; + trans_vec += VEC_SLOTS; + trans_out += VEC_SLOTS; + + vec32<$v01> mat0 = load(trans_mtx, 0).xyzwxyzw; + vec32<$v03> mat1 = load(trans_mtx, 8).xyzwxyzw; + vec32<$v05> mat2 = load(trans_mtx, 0x20).xyzwxyzw; + vec32<$v07> mat3 = load(trans_mtx, 0x28).xyzwxyzw; + + vec32<$v09> vecIn = load(trans_vec); + vec32<$v13> res; + + res = mat0 * vecIn.xxxxXXXX; + res = mat1 +* vecIn.yyyyYYYY; + res = mat2 +* vecIn.zzzzZZZZ; + res = mat3 +* vecIn.wwwwWWWW; + + store(res, trans_out); +} + +include "rsp_rdpq.inc" +)"; + auto asm_ = transpile(src, false); + // Trim trailing whitespace to match JS test's .trimEnd() + while (!asm_.empty() && (asm_.back() == '\n' || asm_.back() == ' ')) + asm_.pop_back(); + // Golden output from JS test + REQUIRE_ASM_EQ(R"(## Auto-generated file, transpiled with RSPL +#include + +.set noreorder +.set noat +.set nomacro + +#undef zero +#undef at +#undef v0 +#undef v1 +#undef a0 +#undef a1 +#undef a2 +#undef a3 +#undef t0 +#undef t1 +#undef t2 +#undef t3 +#undef t4 +#undef t5 +#undef t6 +#undef t7 +#undef s0 +#undef s1 +#undef s2 +#undef s3 +#undef s4 +#undef s5 +#undef s6 +#undef s7 +#undef t8 +#undef t9 +#undef k0 +#undef k1 +#undef gp +#undef sp +#undef fp +#undef ra +.equ hex.$zero, 0 +.equ hex.$at, 1 +.equ hex.$v0, 2 +.equ hex.$v1, 3 +.equ hex.$a0, 4 +.equ hex.$a1, 5 +.equ hex.$a2, 6 +.equ hex.$a3, 7 +.equ hex.$t0, 8 +.equ hex.$t1, 9 +.equ hex.$t2, 10 +.equ hex.$t3, 11 +.equ hex.$t4, 12 +.equ hex.$t5, 13 +.equ hex.$t6, 14 +.equ hex.$t7, 15 +.equ hex.$s0, 16 +.equ hex.$s1, 17 +.equ hex.$s2, 18 +.equ hex.$s3, 19 +.equ hex.$s4, 20 +.equ hex.$s5, 21 +.equ hex.$s6, 22 +.equ hex.$s7, 23 +.equ hex.$t8, 24 +.equ hex.$t9, 25 +.equ hex.$k0, 26 +.equ hex.$k1, 27 +.equ hex.$gp, 28 +.equ hex.$sp, 29 +.equ hex.$fp, 30 +.equ hex.$ra, 31 +#define vco 0 +#define vcc 1 +#define vce 2 + +.data + RSPQ_BeginOverlayHeader + RSPQ_DefineCommand VecCmd_Transform, 8 + RSPQ_EndOverlayHeader + + RSPQ_BeginSavedState + STATE_MEM_START: + .align 4 + VEC_SLOTS: .ds.b 640 + STATE_MEM_END: + RSPQ_EndSavedState + +.text +OVERLAY_CODE_START: + +VecCmd_Transform: + srl $t0, $a1, 16 + andi $t0, $t0, 4080 + andi $t1, $a1, 4080 + andi $t2, $a0, 4080 + addiu $t0, $t0, %lo(VEC_SLOTS) + addiu $t1, $t1, %lo(VEC_SLOTS) + addiu $t2, $t2, %lo(VEC_SLOTS) + ldv $v01, 0, 0, $t0 + ldv $v01, 8, 0, $t0 + ldv $v02, 0, 8, $t0 + ldv $v02, 8, 8, $t0 + ldv $v03, 0, 8, $t0 + ldv $v03, 8, 8, $t0 + ldv $v04, 0, 16, $t0 + ldv $v04, 8, 16, $t0 + ldv $v05, 0, 32, $t0 + ldv $v05, 8, 32, $t0 + ldv $v06, 0, 40, $t0 + ldv $v06, 8, 40, $t0 + ldv $v07, 0, 40, $t0 + ldv $v07, 8, 40, $t0 + ldv $v08, 0, 48, $t0 + ldv $v08, 8, 48, $t0 + lqv $v09, 0, 0, $t1 + lqv $v10, 0, 16, $t1 + vmudl $v29, $v02, $v10.h0 + vmadm $v29, $v01, $v10.h0 + vmadn $v14, $v02, $v09.h0 + vmadh $v13, $v01, $v09.h0 + vmadl $v29, $v04, $v10.h1 + vmadm $v29, $v03, $v10.h1 + vmadn $v14, $v04, $v09.h1 + vmadh $v13, $v03, $v09.h1 + vmadl $v29, $v06, $v10.h2 + vmadm $v29, $v05, $v10.h2 + vmadn $v14, $v06, $v09.h2 + vmadh $v13, $v05, $v09.h2 + vmadl $v29, $v08, $v10.h3 + vmadm $v29, $v07, $v10.h3 + vmadn $v14, $v08, $v09.h3 + vmadh $v13, $v07, $v09.h3 + sqv $v13, 0, 0, $t2 + sqv $v14, 0, 16, $t2 + j RSPQ_Loop + nop + +OVERLAY_CODE_END: + +#define zero $0 +#define v0 $2 +#define v1 $3 +#define a0 $4 +#define a1 $5 +#define a2 $6 +#define a3 $7 +#define t0 $8 +#define t1 $9 +#define t2 $10 +#define t3 $11 +#define t4 $12 +#define t5 $13 +#define t6 $14 +#define t7 $15 +#define s0 $16 +#define s1 $17 +#define s2 $18 +#define s3 $19 +#define s4 $20 +#define s5 $21 +#define s6 $22 +#define s7 $23 +#define t8 $24 +#define t9 $25 +#define k0 $26 +#define k1 $27 +#define gp $28 +#define sp $29 +#define fp $30 +#define ra $31 + +.set at +.set macro +#include )", asm_); +} diff --git a/cpp/tests/test_immediateScalar.cpp b/cpp/tests/test_immediateScalar.cpp new file mode 100644 index 0000000..3bfd795 --- /dev/null +++ b/cpp/tests/test_immediateScalar.cpp @@ -0,0 +1,92 @@ +#include +#include "pipeline.h" + +static const auto CONF = rspl::TranspileConfig{.rspqWrapper = false}; + +TEST_CASE("Immediate Scalar - Unsigned", "[immediateScalar]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test() +{ + u32<$t0> c = TEST_CONST; + LOAD_A: c = 0; + LOAD_B: c = 0xFF; + LOAD_C: c = 0xFFFF; + LOAD_D: c = 0x7FFF; + LOAD_E: c = 0x8000; + LOAD_F: c = 0xFF120000; + LOAD_G: c = 0xFFFF7FFF; + LOAD_H: c = 0xFFFF8000; + LOAD_I: c = 0xFFFFF; + LOAD_J: c = 0xFFFFFFFF; +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + ori $t0, $zero, %lo(TEST_CONST) + LOAD_A: + or $t0, $zero, $zero + LOAD_B: + addiu $t0, $zero, 255 + LOAD_C: + ori $t0, $zero, 0xFFFF + LOAD_D: + addiu $t0, $zero, 32767 + LOAD_E: + ori $t0, $zero, 0x8000 + LOAD_F: + lui $t0, 0xFF12 + LOAD_G: + lui $t0, 0xFFFF + ori $t0, $t0, 0x7FFF + LOAD_H: + addiu $t0, $zero, -32768 + LOAD_I: + lui $t0, 0x0F + ori $t0, $t0, 0xFFFF + LOAD_J: + addiu $t0, $zero, -1 + jr $ra + nop)"); +} + +TEST_CASE("Immediate Scalar - Signed", "[immediateScalar]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test() +{ + s32<$t0> c = TEST_CONST; + LOAD_A: c = 0xFF; + LOAD_B: c = 0xFFFF; + LOAD_C: c = 0xFFFFF; + LOAD_D: c = -255; + LOAD_E: c = -65535; + LOAD_F: c = -1048575; + LOAD_G: c = -2147483648; + LOAD_H: c = 2147483647; +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + ori $t0, $zero, %lo(TEST_CONST) + LOAD_A: + addiu $t0, $zero, 255 + LOAD_B: + ori $t0, $zero, 0xFFFF + LOAD_C: + lui $t0, 0x0F + ori $t0, $t0, 0xFFFF + LOAD_D: + addiu $t0, $zero, -255 + LOAD_E: + lui $t0, 0xFFFF + ori $t0, $t0, 0x01 + LOAD_F: + lui $t0, 0xFFF0 + ori $t0, $t0, 0x01 + LOAD_G: + lui $t0, 0x8000 + LOAD_H: + lui $t0, 0x7FFF + ori $t0, $t0, 0xFFFF + jr $ra + nop)"); +} diff --git a/cpp/tests/test_labels.cpp b/cpp/tests/test_labels.cpp new file mode 100644 index 0000000..43c5a3b --- /dev/null +++ b/cpp/tests/test_labels.cpp @@ -0,0 +1,65 @@ +#include +#include "pipeline.h" + +TEST_CASE("Labels - Basic", "[labels]") { + auto result = rspl::transpileSource( + R"( +function test_label() +{ + label_a: + label_b: label_c: +})", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_label: + label_a: + label_b: + label_c: + jr $ra + nop)"); +} + +TEST_CASE("Labels - With instructions", "[labels]") { + auto result = rspl::transpileSource( + R"( +function test_label() +{ + u32<$t0> a; + label_a: + a += 1; + goto label_b; + label_b: + a += 2; + goto label_a; +})", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_label: + label_a: + addiu $t0, $t0, 1 + j label_b + nop + label_b: + addiu $t0, $t0, 2 + j label_a + nop + jr $ra + nop)"); +} + +TEST_CASE("Labels - Label used as value resolves to %lo", "[labels]") { + // Forward label references in expressions must produce %lo(LABEL) + auto res = rspl::transpileSource(R"( +function test() +{ + u32 x; + x = MY_TARGET; + MY_TARGET: +} +)", + {.rspqWrapper = false}); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_.find("%lo(MY_TARGET)") != std::string::npos); +} diff --git a/cpp/tests/test_load.cpp b/cpp/tests/test_load.cpp new file mode 100644 index 0000000..3a9ab86 --- /dev/null +++ b/cpp/tests/test_load.cpp @@ -0,0 +1,332 @@ +#include +#include "pipeline.h" +#include + +static const auto CONF = rspl::TranspileConfig{.rspqWrapper = false}; + +TEST_CASE("Load - Scalar 32-Bit", "[load]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test_scalar_load() +{ + u32<$t0> src, dst; + + dst = load(src); + dst = load(src, 0x10); + dst = load(src, TEST_CONST); + + dst = load(TEST_CONST); + dst = load(TEST_CONST, 0x10); + // dst = load(TEST_CONST, TEST_CONST); Invalid +})", CONF); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_scalar_load: + lw $t1, 0($t0) + lw $t1, 16($t0) + lw $t1, %lo(TEST_CONST)($t0) + lw $t1, %lo(TEST_CONST + 0) + lw $t1, %lo(TEST_CONST + 16) + jr $ra + nop)"); +} + +TEST_CASE("Load - Scalar Cast", "[load]") { + auto result = rspl::transpileSource(R"( +function test_scalar_load() +{ + u32<$t0> src, dst; + + dst:u32 = load(src, 0x10); + dst:u16 = load(src, 0x10); + dst:u8 = load(src, 0x10); + + dst:s32 = load(src, 0x10); + dst:s16 = load(src, 0x10); + dst:s8 = load(src, 0x10); +})", CONF); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_scalar_load: + lw $t1, 16($t0) + lhu $t1, 16($t0) + lbu $t1, 16($t0) + lw $t1, 16($t0) + lh $t1, 16($t0) + lb $t1, 16($t0) + jr $ra + nop)"); +} + +TEST_CASE("Load - Invalid vector load (const not % 16)", "[load]") { + REQUIRE_THROWS_AS(rspl::transpileSource(R"(function test() { + u32<$t0> a; + vec16 err = load(a, 5); +})", CONF), std::runtime_error); + try { rspl::transpileSource(R"(function test() { u32<$t0> a; vec16 err = load(a, 5); })", CONF); } + catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("Invalid full vector-load offset, must be a multiple of 16") != std::string::npos); + } +} + +TEST_CASE("Load - Invalid vector load (vector as addr)", "[load]") { + REQUIRE_THROWS_AS(rspl::transpileSource(R"(state { u32 TEST_CONST; } +function test() { + vec32<$v01> a; + a = load(a, TEST_CONST); +})", CONF), std::runtime_error); + try { rspl::transpileSource(R"(state { u32 TEST_CONST; } function test() { vec32<$v01> a; a = load(a, TEST_CONST); })", CONF); } + catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("load() requires first argument to be a scalar") != std::string::npos); + } +} + +TEST_CASE("Load - Vector 32-Bit", "[load]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test_vector_load() +{ + u32<$t0> src; + vec32<$v01> dst; + + WholeVector: + dst = load(src); + dst = load(src, 0x10); + dst.y = load(src); + dst.z = load(src, 0x10); + //dst = load(src, TEST_CONST); Invalid + //dst = load(TEST_CONST); Invalid + //dst = load(TEST_CONST, 0x10); Invalid + + Swizzle: + dst = load(src).xyzwxyzw; + dst = load(src, 0x10).xyzwxyzw; + dst.y = load(src).xyzwxyzw; + dst.z = load(src, 0x10).xyzwxyzw; + //dst = load(src, TEST_CONST).xyzwxyzw; Invalid + //dst = load(TEST_CONST).xyzwxyzw; Invalid + //dst = load(TEST_CONST, 0x10).xyzwxyzw; Invalid +})", CONF); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_vector_load: + WholeVector: + lqv $v01, 0, 0, $t0 + lqv $v02, 0, 16, $t0 + lqv $v01, 0, 16, $t0 + lqv $v02, 0, 32, $t0 + lqv $v01, 2, 0, $t0 + lqv $v02, 2, 16, $t0 + lqv $v01, 4, 16, $t0 + lqv $v02, 4, 32, $t0 + Swizzle: + ldv $v01, 0, 0, $t0 + ldv $v01, 8, 0, $t0 + ldv $v02, 0, 8, $t0 + ldv $v02, 8, 8, $t0 + ldv $v01, 0, 16, $t0 + ldv $v01, 8, 16, $t0 + ldv $v02, 0, 24, $t0 + ldv $v02, 8, 24, $t0 + ldv $v01, 2, 0, $t0 + ldv $v01, 10, 0, $t0 + ldv $v02, 2, 8, $t0 + ldv $v02, 10, 8, $t0 + ldv $v01, 4, 16, $t0 + ldv $v01, 12, 16, $t0 + ldv $v02, 4, 24, $t0 + ldv $v02, 12, 24, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Load - Vector 32-Bit Split", "[load]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32<$t0> src; + vec32<$v01> dst; + + LeftSide: + dst.xyzw = load(src, 0x00).xyzw; + dst.xyzw = load(src, 0x10).xyzw; + + RightSide: + dst.XYZW = load(src, 0x00).XYZW; + dst.XYZW = load(src, 0x10).XYZW; +})", CONF); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + LeftSide: + ldv $v01, 0, 0, $t0 + ldv $v02, 0, 8, $t0 + ldv $v01, 0, 16, $t0 + ldv $v02, 0, 24, $t0 + RightSide: + ldv $v01, 8, 8, $t0 + ldv $v02, 8, 16, $t0 + ldv $v01, 8, 24, $t0 + ldv $v02, 8, 32, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Load - Vector 32-Bit Unaligned", "[load]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test_vector_load() +{ + u32<$t0> src; + vec32<$v01> dst; + + WholeVector: + dst = load_unaligned(src); + dst = load_unaligned(src, 0x10); + dst.y = load_unaligned(src); + dst.z = load_unaligned(src, 0x10); + + Swizzle: + dst = load_unaligned(src).xyzwxyzw; + dst = load_unaligned(src, 0x10).xyzwxyzw; + dst.y = load_unaligned(src).xyzwxyzw; + dst.z = load_unaligned(src, 0x10).xyzwxyzw; +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_vector_load: + WholeVector: + lqv $v01, 0, 0, $t0 + lrv $v01, 0, 16, $t0 + lqv $v02, 0, 16, $t0 + lrv $v02, 0, 32, $t0 + lqv $v01, 0, 16, $t0 + lrv $v01, 0, 32, $t0 + lqv $v02, 0, 32, $t0 + lrv $v02, 0, 48, $t0 + lqv $v01, 2, 0, $t0 + lrv $v01, 2, 16, $t0 + lqv $v02, 2, 16, $t0 + lrv $v02, 2, 32, $t0 + lqv $v01, 4, 16, $t0 + lrv $v01, 4, 32, $t0 + lqv $v02, 4, 32, $t0 + lrv $v02, 4, 48, $t0 + Swizzle: + ldv $v01, 0, 0, $t0 + ldv $v01, 8, 0, $t0 + ldv $v02, 0, 8, $t0 + ldv $v02, 8, 8, $t0 + ldv $v01, 0, 16, $t0 + ldv $v01, 8, 16, $t0 + ldv $v02, 0, 24, $t0 + ldv $v02, 8, 24, $t0 + ldv $v01, 2, 0, $t0 + ldv $v01, 10, 0, $t0 + ldv $v02, 2, 8, $t0 + ldv $v02, 10, 8, $t0 + ldv $v01, 4, 16, $t0 + ldv $v01, 12, 16, $t0 + ldv $v02, 4, 24, $t0 + ldv $v02, 12, 24, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Load - Vector Cast", "[load]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32<$t0> addr; + vec32<$v01> dst; + + dst:sint = load(addr, 0x10); + dst:ufract = load(addr, 0x10); + + dst:sint = load(addr, 0x10).XY; + dst:ufract = load(addr, 0x10).XY; + + dst:sint.z = load(addr, 0x10).XY; + dst:ufract.z = load(addr, 0x10).XY; +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + lqv $v01, 0, 16, $t0 + lqv $v02, 0, 16, $t0 + llv $v01, 0, 24, $t0 + llv $v02, 0, 24, $t0 + llv $v01, 4, 24, $t0 + llv $v02, 4, 24, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Load - Vector Packed", "[load]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32<$t0> src; + vec16<$v01> dst; + + Unsigned: + dst = load_vec_u8(src, 0x00); + dst.z = load_vec_u8(src, 0x10); + + Signed: + dst.x = load_vec_s8(src, 0x00); + dst.z = load_vec_s8(src, 0x10); +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + Unsigned: + luv $v01, 0, 0, $t0 + luv $v01, 2, 16, $t0 + Signed: + lpv $v01, 0, 0, $t0 + lpv $v01, 2, 16, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Load - Vector Transposed", "[load]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32<$t0> ptr; + vec16<$v08> a; + vec16<$v16> b; + + a = load_transposed(0, ptr, 0x00); + a = load_transposed(0, ptr); + a = load_transposed(1, ptr, 0x10); + b = load_transposed(4, ptr, 0x20); + b = load_transposed(7, ptr, 0x30); + END: +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + ltv $v08, 0, 0, $t0 + ltv $v08, 0, 0, $t0 + ltv $v08, 2, 16, $t0 + ltv $v16, 8, 32, $t0 + ltv $v16, 14, 48, $t0 + END: + jr $ra + nop)"); +} + +TEST_CASE("Load - Invalid Transposed reg", "[load]") { + REQUIRE_THROWS_AS(rspl::transpileSource(R"(function test() { + u32<$t0> ptr; + vec32<$v04> v; + v = load_transposed(0, ptr, 0x00); +})", CONF), std::runtime_error); + try { rspl::transpileSource(R"(function test() { u32<$t0> ptr; vec32<$v04> v; v = load_transposed(0, ptr, 0x00); })", CONF); } + catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("load_transposed() requires result register to be $v00, $v08, $v16 or $v24") != std::string::npos); + } +} + +TEST_CASE("Load - Invalid Transposed offset", "[load]") { + REQUIRE_THROWS_AS(rspl::transpileSource(R"(function test() { + u32<$t0> ptr; + vec32<$v16> v; + v = load_transposed(0, ptr, 0x04); +})", CONF), std::runtime_error); + try { rspl::transpileSource(R"(function test() { u32<$t0> ptr; vec32<$v16> v; v = load_transposed(0, ptr, 0x04); })", CONF); } + catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("load_transposed() requires offset to be multiple of 16") != std::string::npos); + } +} diff --git a/cpp/tests/test_loop.cpp b/cpp/tests/test_loop.cpp new file mode 100644 index 0000000..a10d68a --- /dev/null +++ b/cpp/tests/test_loop.cpp @@ -0,0 +1,302 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Loops - Basic While-Loop", "[loops]") { + auto result = rspl::transpileSource( + R"(function test() + { + u32<$t0> i=0; + while(i<10) { + i+=1; + } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero + LABEL_test_0001: + sltiu $at, $t0, 10 + beq $at, $zero, LABEL_test_0002 + nop + addiu $t0, $t0, 1 + j LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Loops - Nested While-Loop", "[loops]") { + auto result = rspl::transpileSource( + R"(function test() + { + u32<$t0> i=0; + u32<$t1> j=0; + + while(i<10) { + while(j<20) { + j+=1; + } + i+=1; + } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero + or $t1, $zero, $zero + LABEL_test_0001: + sltiu $at, $t0, 10 + beq $at, $zero, LABEL_test_0002 + nop + LABEL_test_0003: + sltiu $at, $t1, 20 + beq $at, $zero, LABEL_test_0004 + nop + addiu $t1, $t1, 1 + j LABEL_test_0003 + nop + LABEL_test_0004: + addiu $t0, $t0, 1 + j LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Loops - While Loop - Break", "[loops]") { + auto result = rspl::transpileSource( + R"(function test() + { + u32<$t0> i=0; + while(i<10) { + break; + i+=1; + } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero + LABEL_test_0001: + sltiu $at, $t0, 10 + beq $at, $zero, LABEL_test_0002 + nop + j LABEL_test_0002 + nop + addiu $t0, $t0, 1 + j LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Loops - While Loop - scoped Break", "[loops]") { + auto result = rspl::transpileSource( + R"(function test() + { + u32<$t0> i=0; + while(i<10) { + if(!i)break; + i+=1; + } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero + LABEL_test_0001: + sltiu $at, $t0, 10 + beq $at, $zero, LABEL_test_0002 + nop + bne $t0, $zero, LABEL_test_0003 + nop + j LABEL_test_0002 + nop + LABEL_test_0003: + addiu $t0, $t0, 1 + j LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Loops - While Loop - Continue", "[loops]") { + auto result = rspl::transpileSource( + R"(function test() + { + u32<$t0> i=0; + while(i<10) { + continue; + i+=1; + } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero + LABEL_test_0001: + sltiu $at, $t0, 10 + beq $at, $zero, LABEL_test_0002 + nop + j LABEL_test_0001 + nop + addiu $t0, $t0, 1 + j LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Loops - Do-While-Loop !=", "[loops]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32 a = 0; + u32 b = 10; + loop { + a += 1; + } while(a != b) +})", + {.rspqWrapper = false}); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero + addiu $t1, $zero, 10 + LABEL_test_0001: + addiu $t0, $t0, 1 + bne $t0, $t1, LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Loops - Do-While-Loop ==", "[loops]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32 a = 0; + u32 b = 10; + loop { + a += 1; + } while(a == b) +})", + {.rspqWrapper = false}); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero + addiu $t1, $zero, 10 + LABEL_test_0001: + addiu $t0, $t0, 1 + beq $t0, $t1, LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Loops - Do-While-Loop <", "[loops]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32 a = 0; + u32 b = 10; + loop { + a += 1; + } while(a < b) +})", + {.rspqWrapper = false}); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero + addiu $t1, $zero, 10 + LABEL_test_0001: + addiu $t0, $t0, 1 + sltu $at, $t0, $t1 + bne $at, $zero, LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Loops - Do-While-Loop >", "[loops]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32 a = 10; + u32 b = 0; + loop { + a -= 1; + } while(a > b) +})", + {.rspqWrapper = false}); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $t0, $zero, 10 + or $t1, $zero, $zero + LABEL_test_0001: + addiu $t0, $t0, 65535 + sltu $at, $t1, $t0 + bne $at, $zero, LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Loops - Do-While-Loop <=", "[loops]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32 a = 10; + u32 b = 0; + loop { + a -= 1; + } while(a <= b) +})", + {.rspqWrapper = false}); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $t0, $zero, 10 + or $t1, $zero, $zero + LABEL_test_0001: + addiu $t0, $t0, 65535 + sltu $at, $t1, $t0 + beq $at, $zero, LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Loops - Do-While-Loop >=", "[loops]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32 a = 0; + u32 b = 10; + loop { + a += 1; + } while(a >= b) +})", + {.rspqWrapper = false}); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t0, $zero, $zero + addiu $t1, $zero, 10 + LABEL_test_0001: + addiu $t0, $t0, 1 + sltu $at, $t0, $t1 + beq $at, $zero, LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} \ No newline at end of file diff --git a/cpp/tests/test_macros.cpp b/cpp/tests/test_macros.cpp new file mode 100644 index 0000000..79af87a --- /dev/null +++ b/cpp/tests/test_macros.cpp @@ -0,0 +1,99 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Macros - Basic replacement", "[macros]") { + auto result = rspl::transpileSource( + R"( + macro test(u32 add) { + add += 42; + } + + function test_macro() { + u32<$t2> a; + u32<$s3> b; + test(a); + + if(a < 3) { + test(a); + } + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_macro: + addiu $t2, $t2, 42 + sltiu $at, $t2, 3 + beq $at, $zero, LABEL_test_macro_0001 + nop + addiu $t2, $t2, 42 + LABEL_test_macro_0001: + jr $ra + nop)"); +} + +TEST_CASE("Macros - Nested macro", "[macros]") { + auto result = rspl::transpileSource( + R"( + macro test_b(u32 argB) { + argB += 42; + } + + macro test_a(u32 argA) { + test_b(argA); + } + + function test_macro() { + u32<$t2> a; + test_a(a); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_macro: + addiu $t2, $t2, 42 + jr $ra + nop)"); +} + +TEST_CASE("Macros - Scope local", "[macros]") { + auto result = rspl::transpileSource( + R"( + macro test_b(u32 argB) { + argB += 42; + } + + function test_macro() { + u32<$t2> a; + u32<$t3> argB; + test_b(a); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_macro: + addiu $t2, $t2, 42 + jr $ra + nop)"); +} + +TEST_CASE("Macros - Return Value", "[macros]") { + auto result = rspl::transpileSource( + R"( + macro test_a(u32 res, u32 argA, u32 argB) { + res = argA + argB; + } + + function test_macro() { + u32<$a0> argA, argB; + u32<$s0> a = test_a(argA, argB); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_macro: + addu $s0, $a0, $a1 + jr $ra + nop)"); +} diff --git a/cpp/tests/test_optAssert.cpp b/cpp/tests/test_optAssert.cpp new file mode 100644 index 0000000..0bd9b9e --- /dev/null +++ b/cpp/tests/test_optAssert.cpp @@ -0,0 +1,178 @@ +#include +#include "pipeline.h" + +static rspl::TranspileResult transpile(const std::string &src, bool optimize) { + return rspl::transpileSource(src, + {.rspqWrapper = false, .optimize = optimize}); +} + +TEST_CASE("Optimizer E2E - Assertion - Assert variations (unopt)", "[optAssert]") { + auto res = transpile(R"(function a() +{ + u32 buff,test; + TEST_A: + if(buff > 4)assert(0xAB); + + TEST_B: + if(buff < 4)assert(0xAB); + + TEST_C: + if(buff == 4)assert(0xAB); + + TEST_D: + if(buff != 4)assert(0xAB); + + TEST_E: + if(buff == 0)assert(0xAB); + + TEST_F: + if(buff != 0)assert(0xAB); + + TEST_G: + if(buff != test)assert(0xAB); + + TEST_H: + if(buff < test)assert(0xAB); + + TEST_I: + assert(0xAB); +})", + false); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(a: + TEST_A: + sltiu $at, $t0, 5 + bne $at, $zero, LABEL_a_0001 + nop + lui $at, 171 + j assertion_failed + nop + LABEL_a_0001: + TEST_B: + sltiu $at, $t0, 4 + beq $at, $zero, LABEL_a_0002 + nop + lui $at, 171 + j assertion_failed + nop + LABEL_a_0002: + TEST_C: + addiu $at, $zero, 4 + bne $t0, $at, LABEL_a_0003 + nop + lui $at, 171 + j assertion_failed + nop + LABEL_a_0003: + TEST_D: + addiu $at, $zero, 4 + beq $t0, $at, LABEL_a_0004 + nop + lui $at, 171 + j assertion_failed + nop + LABEL_a_0004: + TEST_E: + bne $t0, $zero, LABEL_a_0005 + nop + lui $at, 171 + j assertion_failed + nop + LABEL_a_0005: + TEST_F: + beq $t0, $zero, LABEL_a_0006 + nop + lui $at, 171 + j assertion_failed + nop + LABEL_a_0006: + TEST_G: + beq $t0, $t1, LABEL_a_0007 + nop + lui $at, 171 + j assertion_failed + nop + LABEL_a_0007: + TEST_H: + sltu $at, $t0, $t1 + beq $at, $zero, LABEL_a_0008 + nop + lui $at, 171 + j assertion_failed + nop + LABEL_a_0008: + TEST_I: + lui $at, 171 + j assertion_failed + nop + jr $ra + nop)"); +} + +TEST_CASE("Optimizer E2E - Assertion - Assert variations (opt)", "[optAssert]") { + auto res = transpile(R"(function a() +{ + u32 buff,test; + TEST_A: + if(buff > 4)assert(0xAB); + + TEST_B: + if(buff < 4)assert(0xAB); + + TEST_C: + if(buff == 4)assert(0xAB); + + TEST_D: + if(buff != 4)assert(0xAB); + + TEST_E: + if(buff == 0)assert(0xAB); + + TEST_F: + if(buff != 0)assert(0xAB); + + TEST_G: + if(buff != test)assert(0xAB); + + TEST_H: + if(buff < test)assert(0xAB); + + TEST_I: + assert(0xAB); +})", + true); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(a: + TEST_A: + sltiu $at, $t0, 5 + beq $at, $zero, assertion_failed + lui $at, 171 + TEST_B: + sltiu $at, $t0, 4 + bne $at, $zero, assertion_failed + lui $at, 171 + TEST_C: + addiu $at, $zero, 4 + beq $t0, $at, assertion_failed + lui $at, 171 + TEST_D: + addiu $at, $zero, 4 + bne $t0, $at, assertion_failed + lui $at, 171 + TEST_E: + beq $t0, $zero, assertion_failed + lui $at, 171 + TEST_F: + bne $t0, $zero, assertion_failed + lui $at, 171 + TEST_G: + bne $t0, $t1, assertion_failed + lui $at, 171 + TEST_H: + sltu $at, $t0, $t1 + bne $at, $zero, assertion_failed + lui $at, 171 + TEST_I: + j assertion_failed + lui $at, 171)"); +} diff --git a/cpp/tests/test_optBranchJump.cpp b/cpp/tests/test_optBranchJump.cpp new file mode 100644 index 0000000..bf97b9d --- /dev/null +++ b/cpp/tests/test_optBranchJump.cpp @@ -0,0 +1,99 @@ +#include +#include "pipeline.h" + +static rspl::TranspileResult optTranspile(const std::string &src) { + return rspl::transpileSource(src, {.rspqWrapper = false, .optimize = true}); +} + +TEST_CASE("Optimizer E2E - Branch-Jump - Branch + Goto", "[optBranchJump]") { + auto res = optTranspile(R"(function test() +{ + u32<$t0> a; + LABEL_A: + if(a != 0)goto LABEL_A; +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + LABEL_A: + bne $t0, $zero, LABEL_A + nop + jr $ra + nop)"); +} + +TEST_CASE("Optimizer E2E - Branch-Jump - Branch + Goto (no opt)", "[optBranchJump]") { + auto res = optTranspile(R"(function test() +{ + u32<$t0> a; + LABEL_A: + if(a != 0) { + a += 1; + goto LABEL_A; + } +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + LABEL_A: + beq $t0, $zero, LABEL_test_0001 + nop + j LABEL_A + addiu $t0, $t0, 1 + LABEL_test_0001: + jr $ra + nop)"); +} + +TEST_CASE("Optimizer E2E - Branch-Jump - Loop - Used Label", "[optBranchJump]") { + auto res = optTranspile(R"(function test() +{ + u32<$t0> a; + loop { + if(a == 1)continue; + SOME_LABEL: + + if(a == 0)goto SOME_LABEL; + LOOP_END: + } +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + LABEL_test_0001: + addiu $at, $zero, 1 + beq $t0, $at, LABEL_test_0001 + nop + SOME_LABEL: + bne $t0, $zero, LABEL_test_0001 + nop + j SOME_LABEL + nop + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Optimizer E2E - Branch-Jump - Loop - Unused Label", "[optBranchJump]") { + auto res = optTranspile(R"(function test() +{ + u32<$t0> a; + loop { + if(a == 1)continue; + SOME_LABEL: + + if(a == 0)continue; + LOOP_END: + } +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + LABEL_test_0001: + addiu $at, $zero, 1 + beq $t0, $at, LABEL_test_0001 + nop + bne $t0, $zero, LABEL_test_0001 + nop + j LABEL_test_0001 + nop + LABEL_test_0002: + jr $ra + nop)"); +} diff --git a/cpp/tests/test_optDeadCode.cpp b/cpp/tests/test_optDeadCode.cpp new file mode 100644 index 0000000..5a6a2fb --- /dev/null +++ b/cpp/tests/test_optDeadCode.cpp @@ -0,0 +1,69 @@ +#include +#include "pipeline.h" + +static rspl::TranspileResult optTranspile(const std::string &src) { + return rspl::transpileSource(src, {.rspqWrapper = false, .optimize = true}); +} + +TEST_CASE("Optimizer E2E - Dead Code - Jump at end - safe", "[optDeadCode]") { + auto res = optTranspile(R"(function test() +{ + goto TEST; +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + j TEST + nop)"); +} + +TEST_CASE("Optimizer E2E - Dead Code - Jump at end - unsafe with code", "[optDeadCode]") { + auto res = optTranspile(R"(function test() +{ + goto TEST; + u32 x = 2; + x = 3; +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + j TEST + nop + addiu $t0, $zero, 3 + jr $ra + addiu $t0, $zero, 2)"); +} + +TEST_CASE("Optimizer E2E - Dead Code - Jump at end - unsafe", "[optDeadCode]") { + auto res = optTranspile(R"( +function test2(); +function test() +{ + test2(); + u32 x = 2; +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + jal test2 + nop + jr $ra + addiu $t0, $zero, 2)"); +} + +TEST_CASE("Optimizer E2E - Dead Code - Jump in branch - safe jal", "[optDeadCode]") { + auto res = optTranspile(R"( +function test2(); +function test() +{ + u32 x = 1; + if(x) { + test2(); + } +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + addiu $t0, $zero, 1 + bne $t0, $zero, test2 + ori $ra, $zero, LABEL_test_0001 + LABEL_test_0001: + jr $ra + nop)"); +} diff --git a/cpp/tests/test_optDedupeImm.cpp b/cpp/tests/test_optDedupeImm.cpp new file mode 100644 index 0000000..a0edcdb --- /dev/null +++ b/cpp/tests/test_optDedupeImm.cpp @@ -0,0 +1,104 @@ +#include +#include "pipeline.h" + +#include + +static rspl::TranspileResult optTranspile(const std::string &src) { + return rspl::transpileSource(src, {.rspqWrapper = false, .optimize = true}); +} + +TEST_CASE("Optimizer E2E - Dedupe Imm - $at cached across loads", + "[optDedupeImm]") { + auto res = optTranspile(R"( +state { + vec16 MY_VAR[4]; +} +function test() +{ + vec16 a = load(MY_VAR, 0x00); + vec16 b = load(MY_VAR, 0x10); + vec16 c = load(MY_VAR, 0x20); +} +)"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + ori $at, $zero, %lo(MY_VAR) + lqv $v02, 0, 16, $at + lqv $v03, 0, 32, $at + jr $ra + lqv $v01, 0, 0, $at)"); +} + +TEST_CASE("Optimizer E2E - Dedupe Imm - $at NOT cached across branch", + "[optDedupeImm]") { + auto res = optTranspile(R"( +state { + vec16 MY_VAR[4]; +} +function test(u32 cond) +{ + vec16 a = load(MY_VAR, 0x00); + if(cond != 0) { + a += 1; + } + vec16 b = load(MY_VAR, 0x10); +} +)"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + ori $at, $zero, %lo(MY_VAR) + beq $a0, $zero, LABEL_test_0001 + lqv $v01, 0, 0, $at + vaddc $v01, $v01, $v30.e7 + LABEL_test_0001: + ori $at, $zero, %lo(MY_VAR) + jr $ra + lqv $v02, 0, 16, $at)"); +} + +TEST_CASE("Optimizer E2E - Dedupe Imm - $at changes across state vars", + "[optDedupeImm]") { + auto res = optTranspile(R"( +state { + vec16 VAR_A[4]; + vec16 VAR_B[4]; +} +function test() +{ + vec16 a = load(VAR_A, 0x00); + vec16 b = load(VAR_B, 0x00); +} +)"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + ori $at, $zero, %lo(VAR_A) + lqv $v01, 0, 0, $at + ori $at, $zero, %lo(VAR_B) + jr $ra + lqv $v02, 0, 0, $at)"); +} + +TEST_CASE("Optimizer E2E - Dedupe Imm - $at recached after different var", + "[optDedupeImm]") { + auto res = optTranspile(R"( +state { + vec16 VAR_A[4]; + vec16 VAR_B[4]; +} +function test() +{ + vec16 a = load(VAR_A, 0x00); + vec16 b = load(VAR_B, 0x00); + vec16 c = load(VAR_A, 0x10); +} +)"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + ori $at, $zero, %lo(VAR_A) + lqv $v01, 0, 0, $at + ori $at, $zero, %lo(VAR_B) + lqv $v02, 0, 0, $at + ori $at, $zero, %lo(VAR_A) + jr $ra + lqv $v03, 0, 16, $at)"); +} diff --git a/cpp/tests/test_optDelaySlot.cpp b/cpp/tests/test_optDelaySlot.cpp new file mode 100644 index 0000000..985c257 --- /dev/null +++ b/cpp/tests/test_optDelaySlot.cpp @@ -0,0 +1,70 @@ +#include +#include "pipeline.h" + +static rspl::TranspileResult optTranspile(const std::string &src) { + return rspl::transpileSource(src, {.rspqWrapper = false, .optimize = true}); +} + +TEST_CASE("Optimizer E2E - Delay-Slots - Fill - Basic", "[optDelaySlot]") { + auto res = optTranspile(R"(function test(u32 dummy) +{ + u32 a = 1; + goto SOME_LABEL; +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + j SOME_LABEL + addiu $t0, $zero, 1)"); +} + +TEST_CASE("Optimizer E2E - Delay-Slots - Fill - Complex", "[optDelaySlot]") { + auto res = optTranspile(R"(function test(u32 i) +{ + u32 test = 0; + while(i != 0) { + if(i == 6) { + test = 42; + break; + } + i -= 1; + } +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + or $t0, $zero, $zero + LABEL_test_0001: + beq $a0, $zero, LABEL_test_0002 + nop + addiu $at, $zero, 6 + bne $a0, $at, LABEL_test_0003 + nop + j LABEL_test_0002 + addiu $t0, $zero, 42 + LABEL_test_0003: + j LABEL_test_0001 + addiu $a0, $a0, 65535 + LABEL_test_0002: + jr $ra + nop)"); +} + +TEST_CASE("Optimizer E2E - Delay-Slots - Fill across jal (scalar)", + "[optDelaySlot]") { + auto res = optTranspile(R"( +function DMAWaitIdle(); +function test() +{ + u32 a = 1; + u32 b = 2; + DMAWaitIdle(); + u32 c = 3; +} +)"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + addiu $t1, $zero, 2 + jal DMAWaitIdle + addiu $t0, $zero, 1 + jr $ra + addiu $t2, $zero, 3)"); +} diff --git a/cpp/tests/test_optDepScanCtrl.cpp b/cpp/tests/test_optDepScanCtrl.cpp new file mode 100644 index 0000000..f40cde6 --- /dev/null +++ b/cpp/tests/test_optDepScanCtrl.cpp @@ -0,0 +1,59 @@ +#include +#include "asm.h" +#include "optimizer/asm_scan_deps.h" + +#include +#include +#include + +using namespace rspl; + +static std::vector> +asmLinesToDeps(std::vector &lines) { + AsmFunc func; + func.asm_ = lines; + asmInitDeps(func); + lines = std::move(func.asm_); + std::vector> res; + for (size_t i = 0; i < lines.size(); ++i) { + auto r = asmGetReorderIndices(lines, static_cast(i)); + std::sort(r.begin(), r.end()); + res.push_back(std::move(r)); + } + return res; +} + +static AsmInst makeLabel(const std::string &name) { + AsmInst inst; + inst.type = AsmType::LABEL; + inst.cold->label = name; + return inst; +} + +TEST_CASE("Optimizer - Dependency Scanner - Control - Stop at Label", + "[optDepScanCtrl]") { + std::vector lines = { + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t1", "$zero", "$zero"}), + makeLabel("SOME_LABEL"), + asmOp("or", {"$t2", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = {{0, 1}, {0, 1}, {2}, {3}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Control - Stop at Jump", + "[optDepScanCtrl]") { + std::vector lines = { + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t1", "$zero", "$zero"}), + asmOp("j", {"SOME_WHERE"}), + asmOp("or", {"$t2", "$zero", "$zero"}), // delay slot (filled) + asmOp("or", {"$t2", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0, 1}, {0, 1}, {2}, {0, 1, 2, 3}, {4}}; + REQUIRE(deps == expected); +} diff --git a/cpp/tests/test_optDepScanMem.cpp b/cpp/tests/test_optDepScanMem.cpp new file mode 100644 index 0000000..598a9e5 --- /dev/null +++ b/cpp/tests/test_optDepScanMem.cpp @@ -0,0 +1,89 @@ +#include +#include "asm.h" +#include "optimizer/asm_scan_deps.h" +#include "state.h" + +#include +#include +#include + +using namespace rspl; + +static std::vector> +asmLinesToDeps(std::vector &lines) { + AsmFunc func; + func.asm_ = lines; + asmInitDeps(func); + lines = std::move(func.asm_); + std::vector> res; + for (size_t i = 0; i < lines.size(); ++i) { + auto r = asmGetReorderIndices(lines, static_cast(i)); + std::sort(r.begin(), r.end()); + res.push_back(std::move(r)); + } + return res; +} + +TEST_CASE("Optimizer - Dependency Scanner - Memory - Read vs Read", + "[optDepScanMem]") { + std::vector lines = { + asmOp("lw", {"$t0", "0($s1)"}), + asmOp("or", {"$t1", "$zero", "$zero"}), + asmOp("lw", {"$t2", "0($s1)"}), + asmOp("or", {"$t3", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0, 1, 2, 3}, {0, 1, 2, 3}, {0, 1, 2, 3}, {0, 1, 2, 3}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Memory - Read vs Write", + "[optDepScanMem]") { + std::vector lines = { + asmOp("lw", {"$t0", "0($s1)"}), + asmOp("or", {"$t1", "$zero", "$zero"}), + asmOp("sw", {"$t2", "0($s1)"}), + asmOp("or", {"$t3", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0, 1, 2, 3}, {0, 1, 2, 3}, {0, 1, 2, 3}, {0, 1, 2, 3}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Memory - Read vs Write Barrier", + "[optDepScanMem]") { + std::vector lines = { + asmOp("lw", {"$t0", "0($s1)"}), + asmOp("or", {"$t1", "$zero", "$zero"}), + asmOp("sw", {"$t2", "0($s1)"}), + asmOp("or", {"$t3", "$zero", "$zero"}), + }; + // Annotate with Barrier + lines[0].cold->annotations.push_back({"Barrier", "some barrier"}); + lines[2].cold->annotations.push_back({"Barrier", "some barrier"}); + + state.reset(); + state.enterFunction("test", "command", 0); + state.pushScope("", ""); + + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0, 1}, {0, 1, 2, 3}, {1, 2, 3}, {0, 1, 2, 3}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Memory - Write vs Write", + "[optDepScanMem]") { + std::vector lines = { + asmOp("sw", {"$t0", "0($s2)"}), + asmOp("or", {"$t1", "$zero", "$zero"}), + asmOp("sw", {"$t2", "0($s1)"}), + asmOp("or", {"$t3", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0, 1, 2, 3}, {0, 1, 2, 3}, {0, 1, 2, 3}, {0, 1, 2, 3}}; + REQUIRE(deps == expected); +} diff --git a/cpp/tests/test_optDepScanRegs.cpp b/cpp/tests/test_optDepScanRegs.cpp new file mode 100644 index 0000000..c0dd8cc --- /dev/null +++ b/cpp/tests/test_optDepScanRegs.cpp @@ -0,0 +1,203 @@ +#include +#include "asm.h" +#include "optimizer/asm_scan_deps.h" + +#include +#include +#include + +using namespace rspl; + +static std::vector> +asmLinesToDeps(std::vector &lines) { + AsmFunc func; + func.asm_ = lines; + asmInitDeps(func); + lines = std::move(func.asm_); + std::vector> res; + for (size_t i = 0; i < lines.size(); ++i) { + auto r = asmGetReorderIndices(lines, static_cast(i)); + std::sort(r.begin(), r.end()); + res.push_back(std::move(r)); + } + return res; +} + +TEST_CASE("Optimizer - Dependency Scanner - No Deps", "[optDepScan]") { + std::vector lines = { + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t1", "$zero", "$zero"}), + asmOp("or", {"$t2", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = {{0, 1, 2}, {0, 1, 2}, {0, 1, 2}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Basic Write Dep", "[optDepScan]") { + std::vector lines = { + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t1", "$zero", "$zero"}), + asmOp("or", {"$t2", "$t0", "$zero"}), + asmOp("or", {"$t3", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0, 1}, {0, 1, 2, 3}, {1, 2, 3}, {0, 1, 2, 3}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Nested Write Dep", + "[optDepScan]") { + std::vector lines = { + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t1", "$zero", "$zero"}), + asmOp("or", {"$t2", "$t0", "$zero"}), + asmOp("or", {"$t3", "$t2", "$zero"}), + asmOp("or", {"$t4", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0, 1}, {0, 1, 2, 3, 4}, {1, 2}, {3, 4}, {0, 1, 2, 3, 4}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - MTC2 partial write", + "[optDepScan]") { + std::vector lines = { + asmOp("vxor", {"$v25", "$v00", "$v00.e0"}), + asmOp("addiu", {"$at", "$zero", "3"}), + asmOp("mtc2", {"$at", "$v25.e6"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = {{0, 1, 2}, {0, 1}, {2}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - MTC2 partial write no ret", + "[optDepScan]") { + std::vector lines = { + asmOp("vxor", {"$v25", "$v00", "$v00.e0"}), + asmOp("vxor", {"$v26", "$v00", "$v00"}), + asmOp("mtc2", {"$at", "$v25.e6"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = {{0, 1, 2}, {1, 2}, {1, 2}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - MTC2 partial write ret regs", + "[optDepScan]") { + std::vector lines = { + asmOp("vxor", {"$v25", "$v00", "$v00.e0"}), + asmOp("vxor", {"$v26", "$v00", "$v00"}), + asmOp("mtc2", {"$at", "$v25.e6"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = {{0, 1, 2}, {1, 2}, {1, 2}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Ignore Write no read simple", + "[optDepScan]") { + std::vector lines = { + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t0", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = {{0, 1, 2}, {0, 1, 2}, {2}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Ignore Write no read deps", + "[optDepScan]") { + std::vector lines = { + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t1", "$t0", "$zero"}), + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t0", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0}, {1}, {2}, {3, 4}, {4}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Hidden Regs simple", + "[optDepScan]") { + std::vector lines = { + asmOp("veq", {"$v11", "$v00", "$v00"}), + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("vmrg", {"$v01", "$v02", "$v03"}), + asmOp("or", {"$t0", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0, 1, 2}, {0, 1, 2, 3, 4}, {0, 1, 2, 3, 4}, {1, 2, 3, 4}, {3, 4}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Regs Single-Lane", "[optDepScan]") { + std::vector lines = { + asmOp("vmov", {"$v11.e1", "$v05.e1"}), + asmOp("vmov", {"$v06.e1", "$v11.e2"}), + asmOp("vmov", {"$v07.e1", "$v11.e1"}), + asmOp("vmov", {"$v08.e1", "$v05.e1"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = {{0, 1}, {0, 1, 2, 3}, {1, 2, 3}, {3}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Offset Syntax", "[optDepScan]") { + std::vector lines = { + asmOp("or", {"$t0", "$zero", "$zero"}), + asmOp("or", {"$t1", "$zero", "$zero"}), + asmOp("lw", {"$t2", "0($t0)"}), + asmOp("or", {"$t3", "$zero", "$zero"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0, 1}, {0, 1, 2, 3}, {1, 2, 3}, {0, 1, 2, 3}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Vector mul+add", "[optDepScan]") { + std::vector lines = { + asmOp("vmudl", {"$v27", "$v18", "$v26.v"}), + asmOp("vmadm", {"$v27", "$v17", "$v26.v"}), + asmOp("vmadn", {"$v18", "$v18", "$v25.v"}), + asmOp("vmadh", {"$v17", "$v17", "$v25.v"}), + asmOp("vaddc", {"$v18", "$v18", "$v24.v"}), + asmOp("vadd", {"$v17", "$v17", "$v23.v"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = { + {0}, {1}, {2}, {3}, {4}, {5}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Vector vabs", "[optDepScan]") { + std::vector lines = { + asmOp("vmacf", {"$v27", "$v27", "$v27"}), + asmOp("vabs", {"$v01", "$v01", "$v01"}), + asmOp("vmacf", {"$v27", "$v27", "$v27"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = {{0}, {1}, {2}}; + REQUIRE(deps == expected); +} + +TEST_CASE("Optimizer - Dependency Scanner - Vector VCE", "[optDepScan]") { + std::vector lines = { + asmOp("vcr", {"$v01", "$v01", "$v01"}), + asmOp("ori", {"$t0", "$t0", "$t0"}), + asmOp("vcl", {"$v03", "$v03", "$v03"}), + }; + auto deps = asmLinesToDeps(lines); + std::vector> expected = {{0, 1}, {0, 1, 2}, {1, 2}}; + REQUIRE(deps == expected); +} diff --git a/cpp/tests/test_optJumpDedupe.cpp b/cpp/tests/test_optJumpDedupe.cpp new file mode 100644 index 0000000..d810749 --- /dev/null +++ b/cpp/tests/test_optJumpDedupe.cpp @@ -0,0 +1,50 @@ +#include +#include "pipeline.h" + +static rspl::TranspileResult optTranspile(const std::string &src) { + return rspl::transpileSource(src, {.rspqWrapper = false, .optimize = true}); +} + +TEST_CASE("Optimizer E2E - Jump Dedupe - Nested-If Used Label", "[optJumpDedupe]") { + auto res = optTranspile(R"(command<0> test() +{ + u32 a = 1; + if(a > 1) { + if(a > 10) { + a += 1; + } + } +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + addiu $s7, $zero, 1 + sltiu $at, $s7, 2 + bne $at, $zero, RSPQ_Loop + nop + sltiu $at, $s7, 11 + bne $at, $zero, RSPQ_Loop + nop + addiu $s7, $s7, 1 + LABEL_test_0001: + j RSPQ_Loop + nop)"); +} + +TEST_CASE("Optimizer E2E - Jump Dedupe - Nested-If Unused Label", "[optJumpDedupe]") { + auto res = optTranspile(R"(command<0> test() +{ + u32 a = 1; + while(a < 2) { + a -= 1; + } +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + addiu $s7, $zero, 1 + LABEL_test_0001: + sltiu $at, $s7, 2 + beq $at, $zero, RSPQ_Loop + nop + j LABEL_test_0001 + addiu $s7, $s7, 65535)"); +} diff --git a/cpp/tests/test_optLabels.cpp b/cpp/tests/test_optLabels.cpp new file mode 100644 index 0000000..9d36545 --- /dev/null +++ b/cpp/tests/test_optLabels.cpp @@ -0,0 +1,103 @@ +#include +#include "optimizer/patterns/dedupeLabels.h" +#include "asm.h" +#include "pipeline.h" + +static rspl::TranspileResult optTranspile(const std::string &src) { + return rspl::transpileSource(src, {.rspqWrapper = false, .optimize = true}); +} + +// Helpers for direct dedupeLabels unit tests +static rspl::AsmInst L(const std::string &name) { + rspl::AsmInst inst; + inst.type = rspl::AsmType::LABEL; + inst.cold->label = name; + return inst; +} +static rspl::AsmInst O(const std::string &op, + std::vector args = {}) { + rspl::AsmInst inst; + inst.type = rspl::AsmType::OP; + inst.op = rspl::getOpcode(op); + inst.args = std::move(args); + return inst; +} +static rspl::AsmInst B(const std::string &op, + std::vector args, + const std::string &labelEnd) { + rspl::AsmInst inst = O(op, std::move(args)); + inst.cold->labelEnd = labelEnd; + return inst; +} + +TEST_CASE("Optimizer E2E - Labels - De-dupe Labels", "[optLabels]") { + auto res = optTranspile(R"(function test(u32 dummy) +{ + LABEL_A: + LABEL_B: + LABEL_C: + goto LABEL_A; +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + LABEL_C: + j LABEL_C + nop)"); +} + +TEST_CASE("Optimizer - dedupeLabels - consecutive labels deduped to last", + "[optLabels]") { + rspl::AsmFunc func; + func.asm_ = {B("j", {"LABEL_A"}, "LABEL_A"), O("nop"), L("LABEL_A"), + L("LABEL_B"), O("addiu", {"$t0", "$zero", "1"})}; + rspl::dedupeLabels(func); + REQUIRE(func.asm_.size() == 4); + REQUIRE(func.asm_[0].args[0] == "LABEL_B"); + REQUIRE(func.asm_[0].cold->labelEnd == "LABEL_B"); + REQUIRE(func.asm_[2].cold->label == "LABEL_B"); +} + +TEST_CASE("Optimizer - dedupeLabels - __ labels are never deduplicated", + "[optLabels]") { + rspl::AsmFunc func; + func.asm_ = {B("j", {"SKIP"}, "SKIP"), O("nop"), L("__A"), L("__A"), + L("__A"), O("addiu", {"$t0", "$zero", "1"})}; + rspl::dedupeLabels(func); + int labelCount = 0; + for (auto &inst : func.asm_) + if (inst.type == rspl::AsmType::LABEL) ++labelCount; + REQUIRE(labelCount == 3); +} + +TEST_CASE("Optimizer - dedupeLabels - __ label breaks dedup chain", + "[optLabels]") { + rspl::AsmFunc func; + func.asm_ = {B("j", {"SKIP"}, "SKIP"), O("nop"), L("SKIP"), L("__B"), + L("SKIP"), O("addiu", {"$t0", "$zero", "1"})}; + rspl::dedupeLabels(func); + bool hasDunderB = false; + for (auto &inst : func.asm_) + if (inst.cold->label == "__B") hasDunderB = true; + REQUIRE(hasDunderB); +} + +TEST_CASE("Optimizer E2E - Labels - De-dupe Labels - keep single", "[optLabels]") { + auto res = optTranspile(R"(function test(u32 dummy) +{ + LABEL_A: + dummy += 1; + LABEL_B: + dummy += 2; + LABEL_C: + goto LABEL_A; +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + LABEL_A: + addiu $a0, $a0, 1 + LABEL_B: + addiu $a0, $a0, 2 + LABEL_C: + j LABEL_A + nop)"); +} diff --git a/cpp/tests/test_optMergeSequence.cpp b/cpp/tests/test_optMergeSequence.cpp new file mode 100644 index 0000000..03040e4 --- /dev/null +++ b/cpp/tests/test_optMergeSequence.cpp @@ -0,0 +1,53 @@ +#include +#include "pipeline.h" + +static rspl::TranspileResult optTranspile(const std::string &src) { + return rspl::transpileSource(src, {.rspqWrapper = false, .optimize = true}); +} + +TEST_CASE("Optimizer E2E - Merge Sequence - Multiply - Zero fractional", "[optMergeSequence]") { + auto res = optTranspile(R"( +state { vec16 SCREEN_SCALE_OFFSET; } +function test(u32 dummy) +{ + vec32 screenSize; + screenSize:sint = load(SCREEN_SCALE_OFFSET); + screenSize:sfract = 0; + screenSize >>= 8; + END: +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + ori $at, $zero, %lo(SCREEN_SCALE_OFFSET) + lqv $v01, 0, 0, $at + vmudl $v02, $v00, $v31.e7 + vmadm $v01, $v01, $v31.e7 + vmadn $v02, $v00, $v00 + END: + jr $ra + nop)"); +} + +TEST_CASE("Optimizer E2E - Merge Sequence - Multiply - Non-Zero fractional (no opt)", "[optMergeSequence]") { + auto res = optTranspile(R"( +state { vec16 SCREEN_SCALE_OFFSET; } +function test(u32 dummy) +{ + vec32 screenSize; + screenSize:sint = load(SCREEN_SCALE_OFFSET); + screenSize:sfract = 1; + screenSize >>= 8; + END: +})"); + REQUIRE(res.warn.empty()); + REQUIRE(res.asm_ == R"(test: + ori $at, $zero, %lo(SCREEN_SCALE_OFFSET) + lqv $v01, 0, 0, $at + vxor $v02, $v00, $v30.e7 + vmudl $v02, $v02, $v31.e7 + vmadm $v01, $v01, $v31.e7 + vmadn $v02, $v00, $v00 + END: + jr $ra + nop)"); +} diff --git a/cpp/tests/test_optRegScan.cpp b/cpp/tests/test_optRegScan.cpp new file mode 100644 index 0000000..93dd715 --- /dev/null +++ b/cpp/tests/test_optRegScan.cpp @@ -0,0 +1,220 @@ +#include +#include "asm.h" +#include "optimizer/asm_scan_deps.h" + +#include +#include +#include + +using namespace rspl; + +static std::vector allLanes(const std::string ®, int start = 0, + int count = 8) { + std::vector r; + for (int i = 0; i < count; ++i) + r.push_back(reg + "_" + std::to_string((start + i) % 8)); + return r; +} + +static std::string vReg(int n) { + return std::string("$v") + (n < 10 ? "0" : "") + std::to_string(n); +} + +// STV/LTV lanes: 8 consecutive registers starting at base, each at lane +// (8 + i - row) % 8 where row = element_arg / 2. +static std::vector stvLanes(int base, int element) { + int row = element / 2; + std::vector r; + for (int i = 0; i < 8; ++i) + r.push_back(vReg(base + i) + "_" + + std::to_string((8 + i - row) % 8)); + return r; +} + +static std::vector vecRange(const std::string &, int start, + int count) { + std::vector r; + for (int i = 0; i < count; ++i) + r.push_back(vReg(start + i)); + return r; +} + +static AsmInst makeAsm(const std::string &op, + const std::vector &args) { + AsmInst inst = asmOp(op, args); + asmInitDep(inst); + return inst; +} + +static auto sorted(std::vector v) { + std::sort(v.begin(), v.end()); + return v; +} +template static auto sorted(const C &c) { + std::vector v(c.begin(), c.end()); + std::sort(v.begin(), v.end()); + return v; +} + +static std::vector idxs(const std::vector ®s) { + std::vector r; + for (const auto &s : regs) r.push_back(getRegIndex(s)); + return r; +} + +static std::vector stallIdxs(const std::vector ®s) { + std::vector r; + for (const auto &s : regs) r.push_back(getRegStallIndex(s)); + return r; +} + +#define VEC(n) "$v" #n + +TEST_CASE("Optimizer - Register Scanner", "[optRegScan]") { + // Logic + { + auto a = makeAsm("or", {"$t0", "$a1", "$a0"}); + REQUIRE(sorted(a.depsSourceIdx) == sorted({getRegIndex("$a1"), getRegIndex("$a0")})); + REQUIRE(sorted(a.depsTargetIdx) == sorted({getRegIndex("$t0")})); + REQUIRE(sorted(a.depsStallSourceIdx) == sorted({getRegStallIndex("$a1"), getRegStallIndex("$a0")})); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted({getRegStallIndex("$t0")})); + } + + // Arith + { + auto a = makeAsm("addiu", {"$t0", "$t1", "4"}); + REQUIRE(sorted(a.depsSourceIdx) == sorted({getRegIndex("$t1")})); + REQUIRE(sorted(a.depsTargetIdx) == sorted({getRegIndex("$t0")})); + REQUIRE(sorted(a.depsStallSourceIdx) == sorted({getRegStallIndex("$t1")})); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted({getRegStallIndex("$t0")})); + } + + // Vec Store + { + auto a = makeAsm("sdv", {"$v08", "0", "16", "$s6"}); + auto expSrc = allLanes("$v08"); + expSrc.push_back("$s6"); + auto expSrcIdx = sorted([&]() { + std::vector r; + for (auto &s : expSrc) r.push_back(getRegIndex(s)); + return r; + }()); + auto expStall = std::vector{"$v08", "$s6"}; + auto expStallIdx = sorted([&]() { + std::vector r; + for (auto &s : expStall) r.push_back(getRegStallIndex(s)); + return r; + }()); + REQUIRE(sorted(a.depsSourceIdx) == expSrcIdx); + REQUIRE(sorted(a.depsTargetIdx) == sorted({})); + REQUIRE(sorted(a.depsStallSourceIdx) == expStallIdx); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted({})); + } + + // Vec packed Store + { + auto a = makeAsm("sfv", {"$v08", "0", "$s6"}); + auto expSrc = allLanes("$v08"); + expSrc.push_back("$s6"); + auto expSrcIdx = sorted([&]() { + std::vector r; + for (auto &s : expSrc) r.push_back(getRegIndex(s)); + return r; + }()); + auto expStall = std::vector{"$v08", "$s6"}; + auto expStallIdx = sorted([&]() { + std::vector r; + for (auto &s : expStall) r.push_back(getRegStallIndex(s)); + return r; + }()); + REQUIRE(sorted(a.depsSourceIdx) == expSrcIdx); + REQUIRE(sorted(a.depsTargetIdx) == sorted({})); + REQUIRE(sorted(a.depsStallSourceIdx) == expStallIdx); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted({})); + } + + // Lanes - Vec move + { + auto a = makeAsm("vmov", {"$v07.e3", "$v05.e2"}); + REQUIRE(sorted(a.depsSourceIdx) == sorted({getRegIndex("$v05_2")})); + REQUIRE(sorted(a.depsTargetIdx) == sorted({getRegIndex("$v07_3"), getRegIndex("$acc")})); + REQUIRE(sorted(a.depsStallSourceIdx) == sorted({getRegStallIndex("$v05")})); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted({getRegStallIndex("$v07")})); + } + + // Lanes - STV - base: $v16_0..$v23_7 + { + auto a = makeAsm("stv", {"$v16", "0", "0", "$t0"}); + auto expSrc = stvLanes(16, 0); // element=0 + expSrc.push_back("$t0"); + auto expSrcIdx = sorted(idxs(expSrc)); + REQUIRE(sorted(a.depsSourceIdx) == expSrcIdx); + REQUIRE(sorted(a.depsTargetIdx) == sorted({})); + auto expStall = vecRange("$v16", 16, 8); + expStall.push_back("$t0"); + REQUIRE(sorted(a.depsStallSourceIdx) == sorted(stallIdxs(expStall))); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted({})); + } + + // Lanes - STV - offset 2: $v08_7, $v09_0..$v15_6 + { + auto a = makeAsm("stv", {"$v08", "2", "0x10", "$t0"}); + auto expSrc = stvLanes(8, 2); + expSrc.push_back("$t0"); + REQUIRE(sorted(a.depsSourceIdx) == sorted(idxs(expSrc))); + REQUIRE(sorted(a.depsTargetIdx) == sorted({})); + auto expStall = vecRange("$v08", 8, 8); + expStall.push_back("$t0"); + REQUIRE(sorted(a.depsStallSourceIdx) == sorted(stallIdxs(expStall))); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted({})); + } + + // Lanes - STV - offset 8: $v08_4, $v09_5..$v15_3 + { + auto a = makeAsm("stv", {"$v08", "8", "0x20", "$t0"}); + auto expSrc = stvLanes(8, 8); + expSrc.push_back("$t0"); + REQUIRE(sorted(a.depsSourceIdx) == sorted(idxs(expSrc))); + REQUIRE(sorted(a.depsTargetIdx) == sorted({})); + auto expStall = vecRange("$v08", 8, 8); + expStall.push_back("$t0"); + REQUIRE(sorted(a.depsStallSourceIdx) == sorted(stallIdxs(expStall))); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted({})); + } + + // Lanes - LTV - base: $v16_0..$v23_7 + { + auto a = makeAsm("ltv", {"$v16", "0", "0", "$t0"}); + REQUIRE(sorted(a.depsSourceIdx) == sorted({getRegIndex("$t0")})); + REQUIRE(sorted(a.depsTargetIdx) == sorted(idxs(stvLanes(16, 0)))); + REQUIRE(sorted(a.depsStallSourceIdx) == sorted({getRegStallIndex("$t0")})); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted(stallIdxs(vecRange("$v16", 16, 8)))); + } + + // Lanes - LTV - offset 2: $v08_7, $v09_0..$v15_6 + { + auto a = makeAsm("ltv", {"$v08", "2", "0x10", "$t0"}); + REQUIRE(sorted(a.depsSourceIdx) == sorted({getRegIndex("$t0")})); + REQUIRE(sorted(a.depsTargetIdx) == sorted(idxs(stvLanes(8, 2)))); + REQUIRE(sorted(a.depsStallSourceIdx) == sorted({getRegStallIndex("$t0")})); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted(stallIdxs(vecRange("$v08", 8, 8)))); + } + + // Lanes - LTV - offset 8: $v00_4..$v07_3 + { + auto a = makeAsm("ltv", {"$v00", "8", "0x20", "$t0"}); + REQUIRE(sorted(a.depsSourceIdx) == sorted({getRegIndex("$t0")})); + REQUIRE(sorted(a.depsTargetIdx) == sorted(idxs(stvLanes(0, 8)))); + REQUIRE(sorted(a.depsStallSourceIdx) == sorted({getRegStallIndex("$t0")})); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted(stallIdxs(vecRange("$v00", 0, 8)))); + } + + // ctc2 - VCC + { + auto a = makeAsm("ctc2", {"$at", "$vcc"}); + REQUIRE(sorted(a.depsSourceIdx) == sorted({getRegIndex("$at")})); + REQUIRE(sorted(a.depsTargetIdx) == sorted({getRegIndex("$vcc")})); + REQUIRE(sorted(a.depsStallSourceIdx) == sorted({getRegStallIndex("$at")})); + REQUIRE(sorted(a.depsStallTargetIdx) == sorted({})); + } +} diff --git a/cpp/tests/test_preproc.cpp b/cpp/tests/test_preproc.cpp new file mode 100644 index 0000000..af02654 --- /dev/null +++ b/cpp/tests/test_preproc.cpp @@ -0,0 +1,226 @@ +#include +#include "preproc.h" + +#include +#include + +static std::string preproc(const std::string &src) { + std::unordered_map defines; + return rspl::preprocFull(src, defines, "."); +} + +TEST_CASE("Preproc - Define - Basic", "[preproc]") { + auto src = R"( + #define TEST 42 + macro test() { + u32 x = TEST; + } + )"; + auto res = preproc(src); + REQUIRE(res.find("u32 x = 42;") != std::string::npos); +} + +TEST_CASE("Preproc - Define - Multiple", "[preproc]") { + auto src = R"( + #define TEST 42 + #define TEST_AB 43 + + macro test() { + u32 x = TEST; + u32 y = TEST_AB; + } + )"; + auto res = preproc(src); + REQUIRE(res.find("u32 x = 42;") != std::string::npos); + REQUIRE(res.find("u32 y = 43;") != std::string::npos); +} + +TEST_CASE("Preproc - Define - Deps", "[preproc]") { + auto src = R"( + #define TEST 42 + #define TEST_AB TEST+1 + + macro test() { + u32 x = TEST; + u32 y = TEST_AB; + } + )"; + auto res = preproc(src); + REQUIRE(res.find("u32 x = 42;") != std::string::npos); + REQUIRE(res.find("u32 y = 42+1;") != std::string::npos); +} + +TEST_CASE("Preproc - Define - Partial", "[preproc]") { + auto src = R"( + #define my 42 + + macro my_function() { + u32 x = my; + } + )"; + auto res = preproc(src); + REQUIRE(res.find("u32 x = 42;") != std::string::npos); +} + +TEST_CASE("Preproc - Define - Undef", "[preproc]") { + auto src = R"( + #define TEST 42 + macro test() { + u32 x = TEST; + } + #undef TEST + )"; + auto res = preproc(src); + REQUIRE(res.find("u32 x = 42;") != std::string::npos); +} + +TEST_CASE("Preproc - Define - Undef Before usage", "[preproc]") { + auto src = R"( + #define TEST 42 + #undef TEST + + macro test() { + u32 x = TEST; + } + )"; + auto res = preproc(src); + REQUIRE(res.find("u32 x = TEST;") != std::string::npos); +} + +TEST_CASE("Preproc - Define - Empty", "[preproc]") { + auto src = R"( + #define + macro test() { + u32 x = TEST; + } + )"; + REQUIRE_THROWS_AS(preproc(src), std::runtime_error); + try { + preproc(src); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find( + "Invalid #define statement") != std::string::npos); + } +} + +TEST_CASE("Preproc - Ifdef - Basic", "[preproc]") { + auto src = R"( + #define TEST 42 + + #ifdef TEST2 + macro test2() {} + #endif + + #ifdef TEST + macro test() {} + #endif + )"; + auto res = preproc(src); + REQUIRE(res.find("macro test() {}") != std::string::npos); + REQUIRE(res.find("macro test2() {}") == std::string::npos); +} + +TEST_CASE("Preproc - Ifdef - Else", "[preproc]") { + auto src = R"( + #define TEST 42 + + #ifdef TEST2 + macro test2() {} + #else + macro test() {} + #endif + )"; + auto res = preproc(src); + REQUIRE(res.find("macro test() {}") != std::string::npos); + REQUIRE(res.find("macro test2() {}") == std::string::npos); +} + +TEST_CASE("Preproc - Ifdef - define (true)", "[preproc]") { + auto src = R"( + #define TEST 42 + + #ifdef TEST + #define VAL 1 + #else + #define VAL 2 + #endif + VAL + )"; + auto res = preproc(src); + // After preprocessing, VAL should be 1 + REQUIRE(res.find("1") != std::string::npos); +} + +TEST_CASE("Preproc - Ifdef - define (false)", "[preproc]") { + auto src = R"( + #define TEST 42 + + #ifdef TEST_OTHER + #define VAL 1 + #else + #define VAL 2 + #endif + VAL + )"; + auto res = preproc(src); + REQUIRE(res.find("2") != std::string::npos); +} + +TEST_CASE("Preproc - Ifdef - nested", "[preproc]") { + auto src = R"( + #ifdef TEST + #ifdef TEST2 + #endif + #endif + + )"; + REQUIRE_THROWS_AS(preproc(src), std::runtime_error); + try { + preproc(src); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find( + "Nested #ifdef") != std::string::npos); + } +} + +TEST_CASE("Preproc - Defines emitted in source order", "[preproc]") { + std::unordered_map defines; + std::vector defineOrder; + auto src = R"( +#define TRI_BUFFER_COUNT 70 +#define LIGHT_COUNT 8 +)"; + rspl::preprocFull(src, defines, ".", &defineOrder); + REQUIRE(defineOrder.size() == 2); + REQUIRE(defineOrder[0].name == "TRI_BUFFER_COUNT"); + REQUIRE(defineOrder[1].name == "LIGHT_COUNT"); +} + +TEST_CASE("Preproc - #undef removes define from ordered output", + "[preproc]") { + std::unordered_map defines; + std::vector defineOrder; + auto src = R"( +#define KEEP_ME 42 +#define REMOVE_ME 99 +#undef REMOVE_ME +)"; + rspl::preprocFull(src, defines, ".", &defineOrder); + // Both were pushed in order, but REMOVE_ME is gone from map + REQUIRE(defines.count("KEEP_ME") == 1); + REQUIRE(defines.count("REMOVE_ME") == 0); + REQUIRE(defineOrder.size() == 2); // both were pushed before undef +} + +TEST_CASE("Preproc - stripComments handles large block comments", + "[preproc]") { + auto src = R"(/*************************************** + * Multi-line block comment + ***************************************/ +#define TEST_VALUE 42 +)"; + std::unordered_map defines; + std::string result = rspl::preprocFull(src, defines, "."); + REQUIRE(defines.count("TEST_VALUE") == 1); + REQUIRE(defines["TEST_VALUE"].value == "42"); +} diff --git a/cpp/tests/test_scalarOps.cpp b/cpp/tests/test_scalarOps.cpp new file mode 100644 index 0000000..37b2546 --- /dev/null +++ b/cpp/tests/test_scalarOps.cpp @@ -0,0 +1,250 @@ +#include +#include "pipeline.h" + +#include + +static const auto CONF = rspl::TranspileConfig{.rspqWrapper = false}; + +TEST_CASE("Scalar Ops - 32-Bit Arithmetic", "[scalarOps]") { + auto result = rspl::transpileSource(R"(state { u32 TEST_CONST; } +function test_scalar_ops() +{ + u32<$t0> a, b, c; + s32<$t3> sa, sb, sc; + + ADD: + c = a + b; sc = sa + sb; + c = a + 1; sc = sa + 1; + c = a + TEST_CONST; sc = sa + TEST_CONST; + + SUB: + c = a - b; sc = sa - sb; + c = a - 1; sc = sa - 1; + //c = a - TEST_CONST; sc = sa - TEST_CONST; Invalid + + MUL: + c = a * 4; + + DIV: + c = a / 8; +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_scalar_ops: + ADD: + addu $t2, $t0, $t1 + addu $t5, $t3, $t4 + addiu $t2, $t0, 1 + addiu $t5, $t3, 1 + addiu $t2, $t0, %lo(TEST_CONST) + addiu $t5, $t3, %lo(TEST_CONST) + SUB: + subu $t2, $t0, $t1 + subu $t5, $t3, $t4 + addiu $t2, $t0, 65535 + addiu $t5, $t3, 65535 + MUL: + sll $t2, $t0, 2 + DIV: + srl $t2, $t0, 3 + jr $ra + nop)"); +} + +TEST_CASE("Scalar Ops - 32-Bit Logic", "[scalarOps]") { + auto result = rspl::transpileSource(R"(state { u32 TEST_CONST; } +function test_scalar_ops() +{ + u32<$t0> a, b, c; + s32<$t3> sa, sb, sc; + + AND: + c = a & b; + c = a & 1; + c = a & TEST_CONST; + + OR: + c = a | b; + c = a | 2; + c = a | TEST_CONST; + + XOR: + c = a ^ b; + c = a ^ 2; + c = a ^ TEST_CONST; + + NOT: + c = ~b; + + NOR: + c = a ~| b; + + SHIFT_LEFT: + c = a << b; + c = a << 2; + //c = a << TEST_CONST; Invalid + + SHIFT_RIGHT: + c = a >> b; + c = a >> 2; + sc = sa >> sb; + sc = sa >> 2; + sc = sa >>> 2; + + //c = a >> TEST_CONST; Invalid +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_scalar_ops: + AND: + and $t2, $t0, $t1 + andi $t2, $t0, 1 + andi $t2, $t0, %lo(TEST_CONST) + OR: + or $t2, $t0, $t1 + ori $t2, $t0, 2 + ori $t2, $t0, %lo(TEST_CONST) + XOR: + xor $t2, $t0, $t1 + xori $t2, $t0, 2 + xori $t2, $t0, %lo(TEST_CONST) + NOT: + nor $t2, $zero, $t1 + NOR: + nor $t2, $t0, $t1 + SHIFT_LEFT: + sllv $t2, $t0, $t1 + sll $t2, $t0, 2 + SHIFT_RIGHT: + srlv $t2, $t0, $t1 + srl $t2, $t0, 2 + srav $t5, $t3, $t4 + sra $t5, $t3, 2 + srl $t5, $t3, 2 + jr $ra + nop)"); +} + +TEST_CASE("Scalar Ops - Multiplication (2^x)", "[scalarOps]") { + auto result = rspl::transpileSource(R"(function test() { + u32<$t0> a, b; + a = b * 4; +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + sll $t0, $t1, 2 + jr $ra + nop)"); +} + +TEST_CASE("Scalar Ops - Division (2^x)", "[scalarOps]") { + auto result = rspl::transpileSource(R"(function test() { + u32<$t0> a, b; + a = b / 8; +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + srl $t0, $t1, 3 + jr $ra + nop)"); +} + +TEST_CASE("Scalar Ops - Assign scalar", "[scalarOps]") { + auto result = rspl::transpileSource(R"(function test() { + u32<$t0> a; + u32<$t1> b = a; +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + or $t1, $zero, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Scalar Ops - Assign Vector (ufract)", "[scalarOps]") { + auto result = rspl::transpileSource(R"(function test() { + vec32 v0; + vec16 v1; + u32 a = v0:ufract.y; + u32 b = v1:ufract.y; +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + mfc2 $t0, $v02.e1 + mfc2 $t1, $v03.e1 + jr $ra + nop)"); +} + +TEST_CASE("Scalar Ops - Assign Vector (sint)", "[scalarOps]") { + auto result = rspl::transpileSource(R"(function test() { + vec32 v0; + vec16 v1; + u32 a = v0:sint.y; + u32 b = v1:sint.y; +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + mfc2 $t0, $v01.e1 + mfc2 $t1, $v03.e1 + jr $ra + nop)"); +} + +TEST_CASE("Scalar Ops - Invalid (multiplication)", "[scalarOps]") { + REQUIRE_THROWS_AS( + rspl::transpileSource(R"(function test() { + u32<$t0> a, b; + a = a * b; +})", CONF), + std::runtime_error); + try { + rspl::transpileSource(R"(function test() { + u32<$t0> a, b; + a = a * b; +})", CONF); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("Scalar-Multiplication only allowed with a power-of-two") != std::string::npos); + } +} + +TEST_CASE("Scalar Ops - Invalid (division)", "[scalarOps]") { + REQUIRE_THROWS_AS( + rspl::transpileSource(R"(function test() { + u32<$t0> a, b; + a = a / b; +})", CONF), + std::runtime_error); + try { + rspl::transpileSource(R"(function test() { + u32<$t0> a, b; + a = a / b; +})", CONF); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("Scalar-Division only allowed with a power-of-two") != std::string::npos); + } +} + +TEST_CASE("Scalar Ops - Invalid (sub with label)", "[scalarOps]") { + REQUIRE_THROWS_AS( + rspl::transpileSource(R"(state { u32 TEST_CONST; } +function test() { + u32<$t0> a; + a = a - TEST_CONST; +})", CONF), + std::runtime_error); + try { + rspl::transpileSource(R"(state { u32 TEST_CONST; } +function test() { + u32<$t0> a; + a = a - TEST_CONST; +})", CONF); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("Subtraction cannot use labels") != std::string::npos); + } +} diff --git a/cpp/tests/test_scope.cpp b/cpp/tests/test_scope.cpp new file mode 100644 index 0000000..c5a4dd8 --- /dev/null +++ b/cpp/tests/test_scope.cpp @@ -0,0 +1,83 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Scope - Var Declaration", "[scope]") { + auto result = rspl::transpileSource( + R"(function test_scope() +{ + u32<$t0> a; + { + u32<$t1> b; + b += 2; + } // 'b' is no longer defined now + a += 2; +})", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_scope: + addiu $t1, $t1, 2 + addiu $t0, $t0, 2 + jr $ra + nop)"); +} + +TEST_CASE("Scope - Var Un-Declaration", "[scope]") { + REQUIRE_THROWS_AS( + rspl::transpileSource( + R"(function test_scope() +{ + u32<$t0> a; + a += 2; + undef a; + a = 2; +})", + {.rspqWrapper = false}), + std::runtime_error); + try { + rspl::transpileSource( + R"(function test_scope() +{ + u32<$t0> a; + a += 2; + undef a; + a = 2; +})", + {.rspqWrapper = false}); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("Variable a not known") != std::string::npos); + } +} + +TEST_CASE("Scope - Var Decl. invalid", "[scope]") { + REQUIRE_THROWS_AS( + rspl::transpileSource( + R"(function test_scope() +{ + u32<$t0> a; + { + u32<$t1> b; + b += 2; + } + b += 2; +})", + {.rspqWrapper = false}), + std::runtime_error); + try { + rspl::transpileSource( + R"(function test_scope() +{ + u32<$t0> a; + { + u32<$t1> b; + b += 2; + } + b += 2; +})", + {.rspqWrapper = false}); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("Variable b not known") != std::string::npos); + } +} diff --git a/cpp/tests/test_state.cpp b/cpp/tests/test_state.cpp new file mode 100644 index 0000000..545de48 --- /dev/null +++ b/cpp/tests/test_state.cpp @@ -0,0 +1,95 @@ +#include +#include "state.h" +#include "registers.h" +#include "types.h" + +using namespace rspl; + +TEST_CASE("state basic lifecycle", "[state]") { + state.reset(); + state.enterFunction("test", "function", 0); + + REQUIRE(state.func == "test"); + REQUIRE(state.funcType == "function"); + REQUIRE(state.varExists("ZERO")); + REQUIRE(state.varExists("VZERO")); + REQUIRE(state.varExists("RA")); + + state.leaveFunction(); + REQUIRE(state.func.empty()); +} + +TEST_CASE("state register allocation scalar", "[state]") { + state.reset(); + state.enterFunction("test", "function", 0); + + std::string reg = state.allocRegister("u32"); + REQUIRE(!reg.empty()); + REQUIRE(!reg::isVecReg(reg)); // Scalar type gets scalar register + + state.declareVar("a", "u32", reg); + REQUIRE(state.varExists("a")); + + // Register is marked used; next allocation gets a different one + std::string reg2 = state.allocRegister("u32"); + REQUIRE(reg2 != reg); + + state.leaveFunction(); +} + +TEST_CASE("state register allocation vector", "[state]") { + state.reset(); + state.enterFunction("test", "function", 0); + + std::string reg = state.allocRegister("vec16"); + REQUIRE(!reg.empty()); + REQUIRE(reg::isVecReg(reg)); + + state.leaveFunction(); +} + +TEST_CASE("state scope push/pop", "[state]") { + state.reset(); + state.enterFunction("test", "function", 0); + + std::string reg = state.allocRegister("u32"); + state.declareVar("outer", "u32", reg); + + state.pushScope(); + REQUIRE(state.varExists("outer")); // Inherited from parent + state.declareVar("inner", "u32", state.allocRegister("u32")); + REQUIRE(state.varExists("inner")); + state.popScope(); + + REQUIRE(state.varExists("outer")); + REQUIRE(!state.varExists("inner")); // Gone after pop + + state.leaveFunction(); +} + +TEST_CASE("state label generation", "[state]") { + state.reset(); + state.enterFunction("myFunc", "function", 0); + + std::string label1 = state.generateLabel(); + std::string label2 = state.generateLabel(); + REQUIRE(label1 != label2); + REQUIRE(label1.find("LABEL_myFunc_") == 0); + + state.leaveFunction(); +} + +TEST_CASE("state const and modify tracking", "[state]") { + state.reset(); + state.enterFunction("test", "function", 0); + + std::string reg = state.allocRegister("u32"); + state.declareVar("x", "u32", reg, true); // const + + state.markVarModified("x"); + const VarDef *v = state.getRequiredVar("x", "test", ""); + REQUIRE(v->modifyCount == 1); + REQUIRE(v->isConst == true); + + state.leaveFunction(); +} diff --git a/cpp/tests/test_stateDataBss.cpp b/cpp/tests/test_stateDataBss.cpp new file mode 100644 index 0000000..3a14853 --- /dev/null +++ b/cpp/tests/test_stateDataBss.cpp @@ -0,0 +1,224 @@ +#include +#include "pipeline.h" + +#include + +static std::string getDataSection(const std::string &asm_) { + auto idxData = asm_.find(".data"); + auto idxText = asm_.find(".text"); + if (idxData == std::string::npos || idxText == std::string::npos) + return ""; + return asm_.substr(idxData, idxText - idxData); +} + +TEST_CASE("State - Empty State", "[stateDataBss]") { + auto result = rspl::transpileSource( + R"( + state {} + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + auto data = getDataSection(result.asm_); + REQUIRE(data.find("RSPQ_EmptySavedState") != std::string::npos); +} + +TEST_CASE("State - Types", "[stateDataBss]") { + auto result = rspl::transpileSource( + R"( + state { + u8 a; + u16 b; + u32 c; + vec16 d; + vec32 e; + } + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + auto data = getDataSection(result.asm_); + REQUIRE(data.find("RSPQ_BeginSavedState") != std::string::npos); + REQUIRE(data.find("STATE_MEM_START:") != std::string::npos); + REQUIRE(data.find("a: .ds.b 1") != std::string::npos); + REQUIRE(data.find("b: .ds.b 2") != std::string::npos); + REQUIRE(data.find("c: .ds.b 4") != std::string::npos); + REQUIRE(data.find("d: .ds.b 16") != std::string::npos); + REQUIRE(data.find("e: .ds.b 32") != std::string::npos); + REQUIRE(data.find("STATE_MEM_END:") != std::string::npos); +} + +TEST_CASE("State - Arrays", "[stateDataBss]") { + auto result = rspl::transpileSource( + R"( + state { + u32 a0[1]; + u32 a1[4]; + u32 a2[2][4]; + vec32 b0[1]; + vec32 b1[2]; + vec32 b2[4][2]; + } + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + auto data = getDataSection(result.asm_); + REQUIRE(data.find("a0: .ds.b 4") != std::string::npos); + REQUIRE(data.find("a1: .ds.b 16") != std::string::npos); + REQUIRE(data.find("a2: .ds.b 32") != std::string::npos); + REQUIRE(data.find("b0: .ds.b 32") != std::string::npos); + REQUIRE(data.find("b1: .ds.b 64") != std::string::npos); + REQUIRE(data.find("b2: .ds.b 256") != std::string::npos); +} + +TEST_CASE("State - Extern", "[stateDataBss]") { + auto result = rspl::transpileSource( + R"( + state { + u32 a; + extern u32 b; + u32 c; + } + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + auto data = getDataSection(result.asm_); + REQUIRE(data.find("a: .ds.b 4") != std::string::npos); + REQUIRE(data.find("c: .ds.b 4") != std::string::npos); + REQUIRE(data.find("b:") == std::string::npos); +} + +TEST_CASE("State - Align", "[stateDataBss]") { + auto result = rspl::transpileSource( + R"( + state { + u16 a; + alignas(8) u16 b; + alignas(4) u8 c; + } + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + auto data = getDataSection(result.asm_); + REQUIRE(data.find("a: .ds.b 2") != std::string::npos); + REQUIRE(data.find("b: .ds.b 2") != std::string::npos); + REQUIRE(data.find("c: .ds.b 1") != std::string::npos); +} + +TEST_CASE("State - Align lower", "[stateDataBss]") { + auto result = rspl::transpileSource( + R"( + state { + vec16 VEC_A; + alignas(8) vec16 VEC_A; + } + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + auto data = getDataSection(result.asm_); + REQUIRE(data.find("RSPQ_BeginSavedState") != std::string::npos); + REQUIRE(data.find("STATE_MEM_START:") != std::string::npos); + REQUIRE(data.find(".align 4") != std::string::npos); + REQUIRE(data.find("VEC_A: .ds.b 16") != std::string::npos); + REQUIRE(data.find("STATE_MEM_END:") != std::string::npos); +} + +TEST_CASE("State - Data State", "[stateDataBss]") { + auto result = rspl::transpileSource( + R"( + data { + u32 BBB; + u32 CCC; + } + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + auto data = getDataSection(result.asm_); + REQUIRE(data.find("RSPQ_EmptySavedState") != std::string::npos); + REQUIRE(data.find("BBB: .ds.b 4") != std::string::npos); + REQUIRE(data.find("CCC: .ds.b 4") != std::string::npos); +} + +TEST_CASE("State - BSS Only", "[stateDataBss]") { + auto result = rspl::transpileSource( + R"( + bss { + u32 DDD; + } + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_.find(".bss") != std::string::npos); + REQUIRE(result.asm_.find("DDD: .ds.b 4") != std::string::npos); +} + +TEST_CASE("State - Data + State", "[stateDataBss]") { + auto result = rspl::transpileSource( + R"( + state { + u32 AAA; + } + data { + u32 BBB; + u32 CCC; + } + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + auto data = getDataSection(result.asm_); + REQUIRE(data.find("AAA: .ds.b 4") != std::string::npos); + REQUIRE(data.find("BBB: .ds.b 4") != std::string::npos); + REQUIRE(data.find("CCC: .ds.b 4") != std::string::npos); +} + +TEST_CASE("State - Data + State + BSS", "[stateDataBss]") { + auto result = rspl::transpileSource( + R"( + state { + u32 AAA; + } + data { + u32 BBB; + u32 CCC; + } + bss { + u32 DDD; + } + )", + {.rspqWrapper = true}); + + REQUIRE(result.warn.empty()); + auto data = getDataSection(result.asm_); + REQUIRE(data.find("AAA: .ds.b 4") != std::string::npos); + REQUIRE(data.find("BBB: .ds.b 4") != std::string::npos); + REQUIRE(data.find("CCC: .ds.b 4") != std::string::npos); + REQUIRE(result.asm_.find(".bss") != std::string::npos); + REQUIRE(result.asm_.find("DDD: .ds.b 4") != std::string::npos); +} + +TEST_CASE("State - Extern variables are registered for lookup", + "[stateDataBss]") { + std::string src = R"( +state { + extern u32 RDPQ_CMD_STAGING; + extern u16 RSPQ_Loop; + vec16 MY_VAR; +} +function test(u32 dummy) +{ + u32 x = RDPQ_CMD_STAGING; + u32 y = RSPQ_Loop; +} +)"; + auto result = rspl::transpileSource(src, {.rspqWrapper = false}); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_.find("%lo(RDPQ_CMD_STAGING)") != std::string::npos); + REQUIRE(result.asm_.find("%lo(RSPQ_Loop)") != std::string::npos); +} diff --git a/cpp/tests/test_store.cpp b/cpp/tests/test_store.cpp new file mode 100644 index 0000000..c2984ae --- /dev/null +++ b/cpp/tests/test_store.cpp @@ -0,0 +1,301 @@ +#include +#include "pipeline.h" +#include + +static const auto CONF = rspl::TranspileConfig{.rspqWrapper = false}; + +TEST_CASE("Store - Scalar 32-Bit", "[store]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test_scalar_store() +{ + u32<$t0> val, dst; + + store(val, dst); + store(val, dst, 0x10); + store(val, dst, TEST_CONST); + + store(val, TEST_CONST); + store(val, TEST_CONST, 0x10); +})", CONF); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_scalar_store: + sw $t0, ($t1) + sw $t0, 16($t1) + sw $t0, %lo(TEST_CONST)($t1) + sw $t0, %lo(TEST_CONST)($zero) + sw $t0, %lo(16 + TEST_CONST)($zero) + jr $ra + nop)"); +} + +TEST_CASE("Store - Scalar Cast", "[store]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test_scalar_store() +{ + u32<$t0> val, dst; + + store(val:u32, dst); + store(val:u16, dst); + store(val:u8, dst); + + store(val:s32, dst); + store(val:s16, dst); + store(val:s8, dst); +})", CONF); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_scalar_store: + sw $t0, ($t1) + sh $t0, ($t1) + sb $t0, ($t1) + sw $t0, ($t1) + sh $t0, ($t1) + sb $t0, ($t1) + jr $ra + nop)"); +} + +TEST_CASE("Store - Vector 32-Bit", "[store]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test_vector_store() +{ + u32<$t0> dst; + vec32<$v01> val; + + WholeVector: + store(val, dst); + store(val, TEST_CONST); + + Swizzle: + store(val.y, dst); + store(val.z, dst, 0x10); + store(val.zw, dst); + store(val.zw, dst, 0x10); + store(val.XYZW, dst); + store(val.XYZW, dst, 0x10); +})", CONF); + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_vector_store: + WholeVector: + sqv $v01, 0, 0, $t0 + sqv $v02, 0, 16, $t0 + ori $at, $zero, %lo(TEST_CONST) + sqv $v01, 0, 0, $at + sqv $v02, 0, 16, $at + Swizzle: + ssv $v01, 2, 0, $t0 + ssv $v02, 2, 2, $t0 + ssv $v01, 4, 16, $t0 + ssv $v02, 4, 18, $t0 + slv $v01, 4, 0, $t0 + slv $v02, 4, 4, $t0 + slv $v01, 4, 16, $t0 + slv $v02, 4, 20, $t0 + sdv $v01, 8, 0, $t0 + sdv $v02, 8, 8, $t0 + sdv $v01, 8, 16, $t0 + sdv $v02, 8, 24, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Store - Vector 32-Bit Unaligned", "[store]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test_vector_store() +{ + u32<$t0> dst; + vec32<$v01> val; + + WholeVector: + store_unaligned(val, dst); + store_unaligned(val, TEST_CONST); + + Swizzle: + store_unaligned(val.y, dst); + store_unaligned(val.z, dst, 0x10); + store_unaligned(val.zw, dst); + store_unaligned(val.zw, dst, 0x10); + store_unaligned(val.XYZW, dst); + store_unaligned(val.XYZW, dst, 0x10); +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_vector_store: + WholeVector: + sqv $v01, 0, 0, $t0 + srv $v01, 0, 16, $t0 + sqv $v02, 0, 16, $t0 + srv $v02, 0, 32, $t0 + ori $at, $zero, %lo(TEST_CONST) + sqv $v01, 0, 0, $at + srv $v01, 0, 16, $at + sqv $v02, 0, 16, $at + srv $v02, 0, 32, $at + Swizzle: + ssv $v01, 2, 0, $t0 + ssv $v02, 2, 2, $t0 + ssv $v01, 4, 16, $t0 + ssv $v02, 4, 18, $t0 + slv $v01, 4, 0, $t0 + slv $v02, 4, 4, $t0 + slv $v01, 4, 16, $t0 + slv $v02, 4, 20, $t0 + sdv $v01, 8, 0, $t0 + sdv $v02, 8, 8, $t0 + sdv $v01, 8, 16, $t0 + sdv $v02, 8, 24, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Store - Vector 16-Bit Unaligned", "[store]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test_vector_store() +{ + u32<$t0> dst; + vec16<$v01> val; + + WholeVector: + store_unaligned(val, dst); + store_unaligned(val, TEST_CONST); + + Swizzle: + store_unaligned(val.y, dst); + store_unaligned(val.z, dst, 0x10); + store_unaligned(val.zw, dst); + store_unaligned(val.zw, dst, 0x10); + store_unaligned(val.XYZW, dst); + store_unaligned(val.XYZW, dst, 0x10); +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_vector_store: + WholeVector: + sqv $v01, 0, 0, $t0 + srv $v01, 0, 16, $t0 + ori $at, $zero, %lo(TEST_CONST) + sqv $v01, 0, 0, $at + srv $v01, 0, 16, $at + Swizzle: + ssv $v01, 2, 0, $t0 + ssv $v01, 4, 16, $t0 + slv $v01, 4, 0, $t0 + slv $v01, 4, 16, $t0 + sdv $v01, 8, 0, $t0 + sdv $v01, 8, 16, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Store - Vector Cast", "[store]") { + auto result = rspl::transpileSource(R"( state { u32 TEST_CONST; } +function test_vector_store() +{ + u32<$t0> dst; + vec32<$v01> val; + + WholeVector: + store(val:uint, dst); + store(val:ufract, dst); + + Swizzle: + store(val:uint.XYZW, dst); + store(val:ufract.XYZW, dst); +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test_vector_store: + WholeVector: + sqv $v01, 0, 0, $t0 + sqv $v02, 0, 0, $t0 + Swizzle: + sdv $v01, 8, 0, $t0 + sdv $v02, 8, 0, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Store - Vector Packed", "[store]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32<$t0> dst; + vec16<$v01> val; + + Unsigned: + store_vec_u8(val, dst); + store_vec_u8(val, dst, 0x10); + store_vec_u8(val.y, dst); + store_vec_u8(val.z, dst, 0x10); + + Signed: + store_vec_s8(val, dst); + store_vec_s8(val, dst, 0x10); + store_vec_s8(val.y, dst); + store_vec_s8(val.z, dst, 0x10); +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + Unsigned: + suv $v01, 0, 0, $t0 + suv $v01, 0, 16, $t0 + suv $v01, 1, 0, $t0 + suv $v01, 2, 16, $t0 + Signed: + spv $v01, 0, 0, $t0 + spv $v01, 0, 16, $t0 + spv $v01, 1, 0, $t0 + spv $v01, 2, 16, $t0 + jr $ra + nop)"); +} + +TEST_CASE("Store - Vector Transposed", "[store]") { + auto result = rspl::transpileSource(R"(function test() +{ + u32<$t0> ptr; + vec16<$v08> a; + vec16<$v16> b; + + store_transposed(a, 0, ptr, 0x00); + store_transposed(a, 0, ptr); + store_transposed(a, 1, ptr, 0x10); + store_transposed(b, 4, ptr, 0x20); + store_transposed(b, 7, ptr, 0x30); + END: +})", CONF); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + stv $v08, 0, 0, $t0 + stv $v08, 0, 0, $t0 + stv $v08, 2, 16, $t0 + stv $v16, 8, 32, $t0 + stv $v16, 14, 48, $t0 + END: + jr $ra + nop)"); +} + +TEST_CASE("Store - Invalid Transposed reg", "[store]") { + REQUIRE_THROWS_AS(rspl::transpileSource(R"(function test() { + u32<$t0> ptr; + vec32<$v04> v; + store_transposed(v, 0, ptr, 0x00); +})", CONF), std::runtime_error); + try { rspl::transpileSource(R"(function test() { u32<$t0> ptr; vec32<$v04> v; store_transposed(v, 0, ptr, 0x00); })", CONF); } + catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("store_transposed() requires target register to be $v00, $v08, $v16 or $v24") != std::string::npos); + } +} + +TEST_CASE("Store - Invalid Transposed offset", "[store]") { + REQUIRE_THROWS_AS(rspl::transpileSource(R"(function test() { + u32<$t0> ptr; + vec32<$v16> v; + store_transposed(v, 0, ptr, 0x04); +})", CONF), std::runtime_error); + try { rspl::transpileSource(R"(function test() { u32<$t0> ptr; vec32<$v16> v; store_transposed(v, 0, ptr, 0x04); })", CONF); } + catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("store_transposed() requires offset to be multiple of 16") != std::string::npos); + } +} diff --git a/cpp/tests/test_swizzle.cpp b/cpp/tests/test_swizzle.cpp new file mode 100644 index 0000000..08ec583 --- /dev/null +++ b/cpp/tests/test_swizzle.cpp @@ -0,0 +1,172 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Syntax - Swizzle - Assign single (vec32 <- vec32)", "[swizzle]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> a, b; + a.x = b.X; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmov $v01.e0, $v03.e4 + vmov $v02.e0, $v04.e4 + jr $ra + nop)"); +} + +TEST_CASE("Syntax - Swizzle - Assign single (vec32 <- vec32, cast)", + "[swizzle]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> a, b; + SINT: + a.x = b:sint.X; + a:sint.x = b:sint.X; + a:ufract.x = b:sint.X; + + UFRACT: + a.x = b:ufract.X; + a:sint.x = b:ufract.X; + a:ufract.x = b:ufract.X; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + SINT: + vmov $v01.e0, $v03.e4 + vmov $v02.e0, $v00.e4 + vmov $v01.e0, $v03.e4 + vmov $v02.e0, $v03.e4 + UFRACT: + vmov $v01.e0, $v00.e4 + vmov $v02.e0, $v04.e4 + vmov $v01.e0, $v04.e4 + vmov $v02.e0, $v04.e4 + jr $ra + nop)"); +} + +TEST_CASE("Syntax - Swizzle - Assign single (vec16 <- vec16)", "[swizzle]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a, b; + a.x = b.X; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmov $v01.e0, $v02.e4 + jr $ra + nop)"); +} + +TEST_CASE("Syntax - Swizzle - Assign single (vec32 <- vec16)", "[swizzle]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> a; + vec16<$v03> b; + a.x = b.X; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmov $v01.e0, $v03.e4 + vmov $v02.e0, $v00.e4 + jr $ra + nop)"); +} + +TEST_CASE("Syntax - Swizzle - Assign single (vec16 <- vec32)", "[swizzle]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a; + vec32<$v02> b; + a.x = b.X; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmov $v01.e0, $v02.e4 + jr $ra + nop)"); +} + +TEST_CASE("Syntax - Swizzle - Invalid on Scalar (calc)", "[swizzle]") { + REQUIRE_THROWS_AS( + rspl::transpileSource( + R"(function test() { + u32<$t0> a; + a += a.x; + })", + {.rspqWrapper = false}), + std::runtime_error); + try { + rspl::transpileSource( + R"(function test() { + u32<$t0> a; + a += a.x; + })", + {.rspqWrapper = false}); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find( + "Swizzling not allowed for scalar operations") != std::string::npos); + } +} + +TEST_CASE("Syntax - Swizzle - Invalid on Scalar (assign)", "[swizzle]") { + REQUIRE_THROWS_AS( + rspl::transpileSource( + R"(function test() { + u32<$t0> a; + a = a.x; + })", + {.rspqWrapper = false}), + std::runtime_error); + try { + rspl::transpileSource( + R"(function test() { + u32<$t0> a; + a = a.x; + })", + {.rspqWrapper = false}); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find( + "Swizzling not allowed for scalar operations") != std::string::npos); + } +} + +TEST_CASE("Syntax - Swizzle - Alias (integer index)", "[swizzle]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a; + a.x = a.0; + a.1 = a.z; + + a += a.xxzzXXZZ; + a += a.00224466; + + a += a.wwwwWWWW; + a += a.33337777; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmov $v01.e0, $v01.e0 + vmov $v01.e1, $v01.e2 + vaddc $v01, $v01, $v01.q0 + vaddc $v01, $v01, $v01.q0 + vaddc $v01, $v01, $v01.h3 + vaddc $v01, $v01, $v01.h3 + jr $ra + nop)"); +} diff --git a/cpp/tests/test_syntaxExpansion.cpp b/cpp/tests/test_syntaxExpansion.cpp new file mode 100644 index 0000000..7a1a57f --- /dev/null +++ b/cpp/tests/test_syntaxExpansion.cpp @@ -0,0 +1,304 @@ +#include +#include "pipeline.h" + +#include + +// Helper: assert two RSPL sources produce identical ASM +static void assertSameAsm(const std::string &srcA, + const std::string &srcB) { + auto wrap = [](const std::string &s) { + return "function test() { " + s + " }"; + }; + auto resA = rspl::transpileSource(wrap(srcA), {.rspqWrapper = false}); + auto resB = rspl::transpileSource(wrap(srcB), {.rspqWrapper = false}); + REQUIRE(resA.warn.empty()); + REQUIRE(resB.warn.empty()); + // Both should contain basic function structure + REQUIRE(resA.asm_.find("test:") != std::string::npos); + REQUIRE(resB.asm_.find("test:") != std::string::npos); + REQUIRE(resA.asm_.find("jr $ra") != std::string::npos); + REQUIRE(resB.asm_.find("jr $ra") != std::string::npos); + REQUIRE(resA.asm_ == resB.asm_); +} + +TEST_CASE("Syntax - Expansion - Decl+Assign - Scalar", + "[syntaxExpansion]") { + assertSameAsm("u32 a; a = 1234;", "u32 a = 1234;"); +} + +TEST_CASE("Syntax - Expansion - Decl+Assign - Vector", + "[syntaxExpansion]") { + assertSameAsm("vec16 a; a = 4;", "vec16 a = 4;"); +} + +TEST_CASE("Syntax - Expansion - Decl+Calc - Scalar", + "[syntaxExpansion]") { + assertSameAsm("u32 a,b,c; c = a + b;", "u32 a,b; u32 c = a + b;"); +} + +TEST_CASE("Syntax - Expansion - Decl+Calc - Scalar+Const", + "[syntaxExpansion]") { + assertSameAsm("u32 a,b,c; c = a + 42;", + "u32 a,b; u32 c = a + 42;"); +} + +TEST_CASE("Syntax - Expansion - Decl+Calc - Vector", + "[syntaxExpansion]") { + assertSameAsm("vec16 a,b,c; c = a + b;", + "vec16 a,b; u32 c = a + b;"); +} + +TEST_CASE("Syntax - Expansion - Decl+Calc - Vector+Const", + "[syntaxExpansion]") { + assertSameAsm("vec16 a,b,c; c = a + 32;", + "vec16 a,b; u32 c = a + 32;"); +} + +TEST_CASE("Syntax - Expansion - Assign+Calc - Scalar", + "[syntaxExpansion]") { + assertSameAsm( + R"(u32 a,b; + a = a + b; + a = a - b; + a = a | b; + a = a & b; + a = a ^ b; + a = a >> b; + a = a >>> b; + a = a << b;)", + R"(u32 a,b; + a += b; + a -= b; + a |= b; + a &= b; + a ^= b; + a >>= b; + a >>>= b; + a <<= b;)"); +} + +TEST_CASE("Syntax - Expansion - Assign+Calc - Scalar+Const", + "[syntaxExpansion]") { + assertSameAsm( + R"(u32 a,b; + a = a + 2; + a = a - 2; + a = a | 2; + a = a & 2; + a = a ^ 2; + a = a >> 2; + a = a >>> 2; + a = a << 2;)", + R"(u32 a,b; + a += 2; + a -= 2; + a |= 2; + a &= 2; + a ^= 2; + a >>= 2; + a >>>= 2; + a <<= 2;)"); +} + +TEST_CASE("Syntax - Expansion - Assign+Calc - Vector", + "[syntaxExpansion]") { + assertSameAsm( + R"(vec16 a,b; + a = a + b; + a = a - b; + a = a * b; + a = a | b; + a = a & b; + a = a ^ b;)", + R"(vec16 a,b; + a += b; + a -= b; + a *= b; + a = a | b; + a = a & b; + a = a ^ b;)"); +} + +TEST_CASE("Syntax - Expansion - Assign+Calc - Vector+Const", + "[syntaxExpansion]") { + assertSameAsm( + R"(vec16 a,b; + a = a + 2; + a = a - 2; + a = a * 2; + a = a >> 2; + a = a >>> 2; + a = a << 2;)", + R"(vec16 a,b; + a += 2; + a -= 2; + a *= 2; + a >>= 2; + a >>>= 2; + a <<= 2;)"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc 0 - Scalar", + "[syntaxExpansion]") { + assertSameAsm("u32 a,b,c,d; a = b + c; a = a + d;", + "u32 a,b,c,d; a = b + c + d;"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc 1 - Scalar", + "[syntaxExpansion]") { + assertSameAsm( + "u32 a,b,c,d; a = b + c; a = a + 4; a = a + d;", + "u32 a,b,c,d; a = b + c + 4 + d;"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc 2 - Scalar", + "[syntaxExpansion]") { + assertSameAsm("u32 a,b,c,d; a = b + 8; a = b - 20;", + "u32 a,b,c,d; a = b + 4 + 4; a = b - (2 + 2 * 9);"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc 3 - Scalar", + "[syntaxExpansion]") { + assertSameAsm("u32 a,b,c,d; a = b + 50;", + "u32 a,b,c,d; a = b + ((3 + 2) * 10);"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc 4 - Scalar", + "[syntaxExpansion]") { + assertSameAsm( + "u32 a,b,c,d; u32 tmp = b + c; a = tmp >> d;", + "u32 a,b,c,d; a = (b + c) >> d;"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc 5 - Scalar", + "[syntaxExpansion]") { + assertSameAsm( + "u32 a,b,c,d; u32 tmp0 = b + c; u32 tmp1 = d - a; a = tmp0 >> tmp1;", + "u32 a,b,c,d; a = (b + c) >> (d - a);"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc 5.1 - Scalar", + "[syntaxExpansion]") { + assertSameAsm( + "u32 a,b,c,d; u32 tmp0 = c + d; a = a + tmp0;", + "u32 a,b,c,d; a = a + (c + d);"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc 6 - Vector deeply nested", + "[syntaxExpansion]") { + assertSameAsm( + R"(vec16<$v02> a,b,c,d; + vec16 tmp0, tmp1, tmp2, tmp3, tmp4; + tmp1 = b * c; + tmp0 = a + tmp1; + tmp4 = a + 4; + tmp3 = tmp4 * c; + tmp2 = d - tmp3; + a = tmp0 - tmp2;)", + R"(vec16<$v02> a,b,c,d; + a = (a + b * c) - (d - (a + (2+2)) * c);)"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc 7 - Scalar Increment", + "[syntaxExpansion]") { + assertSameAsm( + "u32 a,b,c,d; u32 tmp0 = b + c; a = a + tmp0;", + "u32 a,b,c,d; a += b + c;"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc - Vector + Cast", + "[syntaxExpansion]") { + assertSameAsm( + R"(vec16 a,b,c,d; + vec16 tmp = b + c; + a = tmp:sfract * d:sfract;)", + R"(vec16 a,b,c,d; + a = (b + c) * d:sfract;)"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc - Scalar Const Op-Test", + "[syntaxExpansion]") { + assertSameAsm( + R"(u32 a,b,c,d; + ARITH: + a = b + 5; + a = b + 6; + a = b + -1; + a = b + 3; + SHIFT: + a = b + 8; + a = b + 4; + a = b + -4; + a = b + -4; + a = b + 1073741820; + LOGIC: + a = b + 0b1111; + a = b + 0b1110; + a = b + 0b110011;)", + R"(u32 a,b,c,d; + ARITH: + a = b + (2 + 3); + a = b + (2 * 3); + a = b + (2 - 3); + a = b + (10 / 3); + SHIFT: + a = b + (1 << 3); + a = b + (16 >> 2); + a = b + (-2 << 1); + a = b + (-16 >> 2); + a = b + (-16 >>> 2); + LOGIC: + a = b + (0b0101 | 0b1010); + a = b + (0b1111 & 0b1110); + a = b + (0b010000 ^ 0b100011);)"); +} + +TEST_CASE( + "Syntax - Expansion - Multi+Calc - Scalar Const Order or Operations", + "[syntaxExpansion]") { + assertSameAsm( + R"(u32 a,b,c,d; + a = b + 10; + a = b + 1; + a = b + 4; + a = b + 0x14;)", + R"(u32 a,b,c,d; + a = b + (1 + 1 * 9); + a = b + (10 - 1 * 9); + a = b + (1 + 1 << 1); + a = b + (1 + 1 << 1 | 0x10);)"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc - Scalar Const Brackets", + "[syntaxExpansion]") { + assertSameAsm( + R"(u32 a,b,c,d; + a = b + 10; + a = b + 20; + a = b + 10; + a = b + 31;)", + R"(u32 a,b,c,d; + a = b + ((5) + (5)); + a = b + ((10) + (((5*2)))); + a = b + ((((((((((((((((10)))))))))))))))); + a = b + (3 * 10) + 1;)"); +} + +TEST_CASE("Syntax - Expansion - Multi+Calc - Scalar Only", + "[syntaxExpansion]") { + assertSameAsm( + R"(u32 a,b,c,d; + a = 10; + a = 20; + a = 30; + a = 40; + a = 100; + a = 200;)", + R"(u32 a,b,c,d; + a = 5 + 5; + a = 40 / 2; + a = 35 - 5; + a = 20 * 2; + a = (10 * 5) * 2; + a = 2 * (1 + 9 * 11);)"); +} diff --git a/cpp/tests/test_syntaxNumbers.cpp b/cpp/tests/test_syntaxNumbers.cpp new file mode 100644 index 0000000..0dbcf20 --- /dev/null +++ b/cpp/tests/test_syntaxNumbers.cpp @@ -0,0 +1,41 @@ +#include +#include "pipeline.h" + +TEST_CASE("Syntax - Numbers - Scalar Assignment", "[syntaxNumbers]") { + auto result = rspl::transpileSource( + R"(function test() +{ + u32<$t0> a; + a = 1234; + a = 0x1234; + a = 0b1010; +})", + {.rspqWrapper = false}); + + REQUIRE(result.asm_ == R"(test: + addiu $t0, $zero, 1234 + addiu $t0, $zero, 4660 + addiu $t0, $zero, 10 + jr $ra + nop)"); +} + +TEST_CASE("Syntax - Numbers - Scalar Calc", "[syntaxNumbers]") { + auto result = rspl::transpileSource( + R"(function test() +{ + u32<$t0> a; + a = a + 1234; + a = a + 0x1234; + a = a + 0b1010; +})", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $t0, $t0, 1234 + addiu $t0, $t0, 4660 + addiu $t0, $t0, 10 + jr $ra + nop)"); +} diff --git a/cpp/tests/test_syntaxVar.cpp b/cpp/tests/test_syntaxVar.cpp new file mode 100644 index 0000000..44cd152 --- /dev/null +++ b/cpp/tests/test_syntaxVar.cpp @@ -0,0 +1,84 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Syntax - Vars - Invalid type (scalar reg for vec)", "[syntaxVar]") { + REQUIRE_THROWS_AS( + rspl::transpileSource( + R"(function test() { + u32<$v03> a; +})", + {.rspqWrapper = false}), + std::runtime_error); + try { + rspl::transpileSource( + R"(function test() { + u32<$v03> a; +})", + {.rspqWrapper = false}); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find( + "Cannot use vector register for scalar variable") != std::string::npos); + } +} + +TEST_CASE("Syntax - Vars - Invalid type (vec reg for scalar)", "[syntaxVar]") { + REQUIRE_THROWS_AS( + rspl::transpileSource( + R"(function test() { + vec16<$t0> a; +})", + {.rspqWrapper = false}), + std::runtime_error); + try { + rspl::transpileSource( + R"(function test() { + vec16<$t0> a; +})", + {.rspqWrapper = false}); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find( + "Cannot use scalar register for vector variable") != std::string::npos); + } +} + +TEST_CASE("Syntax - Vars - Invalid (swizzle in decl)", "[syntaxVar]") { + REQUIRE_THROWS_AS( + rspl::transpileSource( + R"(function test() { + vec16<$v03> a.x; +})", + {.rspqWrapper = false}), + std::runtime_error); + try { + rspl::transpileSource( + R"(function test() { + vec16<$v03> a.x; +})", + {.rspqWrapper = false}); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find("Syntax error") != + std::string::npos); + } +} + +TEST_CASE("Syntax - Vars - Invalid (cast in decl)", "[syntaxVar]") { + REQUIRE_THROWS_AS( + rspl::transpileSource( + R"(function test() { + vec16<$v03> a:sint; +})", + {.rspqWrapper = false}), + std::runtime_error); + try { + rspl::transpileSource( + R"(function test() { + vec16<$v03> a:sint; +})", + {.rspqWrapper = false}); + } catch (const std::runtime_error &e) { + REQUIRE(std::string(e.what()).find( + "Variable name cannot contain a cast") != std::string::npos); + } +} diff --git a/cpp/tests/test_vectorOps.cpp b/cpp/tests/test_vectorOps.cpp new file mode 100644 index 0000000..6d33ed6 --- /dev/null +++ b/cpp/tests/test_vectorOps.cpp @@ -0,0 +1,956 @@ +#include +#include "pipeline.h" + +#include + +TEST_CASE("Vector - Ops - Assign (vec32 vs vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + res = a; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vor $v01, $v00, $v03 + vor $v02, $v00, $v04 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (vec16 vs vec32:cast)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res; + vec32<$v03> a; + res = a:uint; + res = a:sint; + res:ufract = a:ufract; + res:sfract = a:sfract; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vor $v01, $v00, $v03 + vor $v01, $v00, $v03 + vor $v01, $v00, $v04 + vor $v01, $v00, $v04 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (vec16 vs vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a; + res = a; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vor $v01, $v00, $v02 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (vec16 broadcast)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a; + res = a.yyyyYYYY; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vor $v01, $v00, $v02.h1 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (vec32 broadcast)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + res = a.yyyyYYYY; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vor $v01, $v00, $v03.h1 + vor $v02, $v00, $v04.h1 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (swizzle, 2^x)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a; + vec32<$v02> b; + a.x = 2; + b.x = 8; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmov $v01.e0, $v30.e6 + vmov $v02.e0, $v30.e4 + vmov $v03.e0, $v00.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (swizzle, float)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a; + vec32<$v02> b; + a.x = 10.25; + b.x = 42.125; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + addiu $at, $zero, 10 + mtc2 $at, $v01.e0 + addiu $at, $zero, 42 + mtc2 $at, $v02.e0 + addiu $at, $zero, 8192 + mtc2 $at, $v03.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (swizzle, int-variable)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + u32 s; + vec16<$v01> a; + vec32<$v02> b; + a.y = s; + b.z = s; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + mtc2 $t0, $v01.e1 + mtc2 $t0, $v03.e2 + srl $at, $t0, 16 + mtc2 $at, $v02.e2 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (no-swizzle, int-variable)", + "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + u32 s; + vec16<$v01> a; + vec32<$v02> b; + a = s; + b = s; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + mtc2 $t0, $v01.e0 + vor $v01, $v00, $v01.e0 + mtc2 $t0, $v03.e0 + srl $at, $t0, 16 + mtc2 $at, $v02.e0 + vor $v02, $v00, $v02.e0 + vor $v03, $v00, $v03.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (swizzle, 0)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a; + vec32<$v02> b; + a.x = 0; + b.x = 0; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmov $v01.e0, $v00.e0 + vmov $v02.e0, $v00.e0 + vmov $v03.e0, $v00.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (cast, swizzle, 0)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a; + vec32<$v02> b; + a:sint.x = 0; + b:sfract.x = 0; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmov $v01.e0, $v00.e0 + vmov $v03.e0, $v00.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Assign (0)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a = 0; + vec32<$v02> b = 0; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vxor $v01, $v00, $v00.e0 + vxor $v02, $v00, $v00.e0 + vxor $v03, $v00, $v00 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Add (vec32 vs vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + res += a.x; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vaddc $v02, $v02, $v04.e0 + vadd $v01, $v01, $v03.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Add (vec16 vs vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a; + res += a.x; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vaddc $v01, $v01, $v02.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Add (vec16 cast)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a; + res:uint += a.x; + res:sint += a.x; + res:sfract += a.x; + res:ufract += a.x; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vaddc $v01, $v01, $v02.e0 + vadd $v01, $v01, $v02.e0 + vadd $v01, $v01, $v00.e0 + vaddc $v01, $v01, $v00.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Sub (vec32 vs vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + res -= a.y; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vsubc $v02, $v02, $v04.e1 + vsub $v01, $v01, $v03.e1 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Sub (vec16 vs vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a; + res -= a; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vsubc $v01, $v01, $v02.v + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Mul (vec32 vs vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + res *= a.x; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudl $v29, $v02, $v04.e0 + vmadm $v29, $v01, $v04.e0 + vmadn $v02, $v02, $v03.e0 + vmadh $v01, $v01, $v03.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Mul (vec16 vs vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a; + res *= a.x; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudn $v01, $v01, $v02.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Mul (vec16 cast)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a; + res:uint *= a.x; + res:sint *= a.x; + res:ufract *= a.x; + res:sfract *= a.x; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudn $v01, $v01, $v02.e0 + vmudh $v01, $v01, $v02.e0 + vmulu $v01, $v01, $v02.e0 + vmulf $v01, $v01, $v02.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - AND (vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v02> res16, a16; + vec32<$v04> a32; + + res16 = a16 & a16; + res16 = a32 & a16; + res16 = a16 & a32; + res16 = a32 & a32; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vand $v02, $v03, $v03.v + vand $v02, $v04, $v03.v + vand $v02, $v03, $v04.v + vand $v02, $v04, $v04.v + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - AND (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v02> res32, a32; + vec16<$v06> a16; + + res32 = a16 & a16; A: + res32 = a32 & a16; B: + res32 = a16 & a32; C: + res32 = a32 & a32; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vand $v02, $v06, $v06.v + vand $v03, $v00, $v00.v + A: + vand $v02, $v04, $v06.v + vand $v03, $v05, $v00.v + B: + vand $v02, $v06, $v04.v + vand $v03, $v00, $v05.v + C: + vand $v02, $v04, $v04.v + vand $v03, $v05, $v05.v + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - OR (vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v02> res16, a16; + vec32<$v04> a32; + + res16 = a16 | a16; + res16 = a32 | a16; + res16 = a16 | a32; + res16 = a32 | a32; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vor $v02, $v03, $v03.v + vor $v02, $v04, $v03.v + vor $v02, $v03, $v04.v + vor $v02, $v04, $v04.v + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - OR (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v02> res32, a32; + vec16<$v06> a16; + + res32 = a16 | a16; AA: + res32 = a32 | a16; BB: + res32 = a16 | a32; CC: + res32 = a32 | a32; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vor $v02, $v06, $v06.v + vor $v03, $v00, $v00.v + AA: + vor $v02, $v04, $v06.v + vor $v03, $v05, $v00.v + BB: + vor $v02, $v06, $v04.v + vor $v03, $v00, $v05.v + CC: + vor $v02, $v04, $v04.v + vor $v03, $v05, $v05.v + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - XOR (vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v02> res16, a16; + vec32<$v04> a32; + + res16 = a16 ^ a16; + res16 = a32 ^ a16; + res16 = a16 ^ a32; + res16 = a32 ^ a32; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vxor $v02, $v03, $v03.v + vxor $v02, $v04, $v03.v + vxor $v02, $v03, $v04.v + vxor $v02, $v04, $v04.v + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - XOR (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v02> res32, a32; + vec16<$v06> a16; + + res32 = a16 ^ a16; A: + res32 = a32 ^ a16; B: + res32 = a16 ^ a32; C: + res32 = a32 ^ a32; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vxor $v02, $v06, $v06.v + vxor $v03, $v00, $v00.v + A: + vxor $v02, $v04, $v06.v + vxor $v03, $v05, $v00.v + B: + vxor $v02, $v06, $v04.v + vxor $v03, $v00, $v05.v + C: + vxor $v02, $v04, $v04.v + vxor $v03, $v05, $v05.v + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - NOT (vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v02> res16, a16; + vec32<$v04> a32; + + res16 = ~a16; + res16 = ~a32; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vnor $v02, $v03, $v00.v + vnor $v02, $v04, $v00.v + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - NOT (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v02> res32, a32; + vec16<$v06> a16; + + res32 = ~a16; + res32 = ~a32; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vnor $v02, $v06, $v00.v + vnor $v03, $v00, $v00.v + vnor $v02, $v04, $v00.v + vnor $v03, $v05, $v00.v + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Invert-Half (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + a.x = invert_half(a).x; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vrcph $v03.e0, $v03.e0 + vrcpl $v04.e0, $v04.e0 + vrcph $v03.e0, $v00.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Invert-Half - all (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + a = invert_half(a); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vrcph $v03.e0, $v03.e0 + vrcpl $v04.e0, $v04.e0 + vrcph $v03.e0, $v00.e0 + vrcph $v03.e1, $v03.e1 + vrcpl $v04.e1, $v04.e1 + vrcph $v03.e1, $v00.e1 + vrcph $v03.e2, $v03.e2 + vrcpl $v04.e2, $v04.e2 + vrcph $v03.e2, $v00.e2 + vrcph $v03.e3, $v03.e3 + vrcpl $v04.e3, $v04.e3 + vrcph $v03.e3, $v00.e3 + vrcph $v03.e4, $v03.e4 + vrcpl $v04.e4, $v04.e4 + vrcph $v03.e4, $v00.e4 + vrcph $v03.e5, $v03.e5 + vrcpl $v04.e5, $v04.e5 + vrcph $v03.e5, $v00.e5 + vrcph $v03.e6, $v03.e6 + vrcpl $v04.e6, $v04.e6 + vrcph $v03.e6, $v00.e6 + vrcph $v03.e7, $v03.e7 + vrcpl $v04.e7, $v04.e7 + vrcph $v03.e7, $v00.e7 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Invert-SQRT-Half (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + a.x = invert_half_sqrt(a).x; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vrsqh $v03.e0, $v03.e0 + vrsql $v04.e0, $v04.e0 + vrsqh $v03.e0, $v00.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Invert (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + a = invert(a); + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vrcph $v03.e0, $v03.e0 + vrcpl $v04.e0, $v04.e0 + vrcph $v03.e0, $v00.e0 + vrcph $v03.e1, $v03.e1 + vrcpl $v04.e1, $v04.e1 + vrcph $v03.e1, $v00.e1 + vrcph $v03.e2, $v03.e2 + vrcpl $v04.e2, $v04.e2 + vrcph $v03.e2, $v00.e2 + vrcph $v03.e3, $v03.e3 + vrcpl $v04.e3, $v04.e3 + vrcph $v03.e3, $v00.e3 + vrcph $v03.e4, $v03.e4 + vrcpl $v04.e4, $v04.e4 + vrcph $v03.e4, $v00.e4 + vrcph $v03.e5, $v03.e5 + vrcpl $v04.e5, $v04.e5 + vrcph $v03.e5, $v00.e5 + vrcph $v03.e6, $v03.e6 + vrcpl $v04.e6, $v04.e6 + vrcph $v03.e6, $v00.e6 + vrcph $v03.e7, $v03.e7 + vrcpl $v04.e7, $v04.e7 + vrcph $v03.e7, $v00.e7 + vmudn $v04, $v04, $v30.e6 + vmadh $v03, $v03, $v30.e6 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Shift Left (vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v02> a, b; + b = a << 1; + b = a << 4; + b = a << 15; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudn $v03, $v02, $v30.e6 + vmudn $v03, $v02, $v30.e3 + vmudn $v03, $v02, $v31.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Shift Right Arithmetic (vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v02> a, b; + b = a >> 1; + b = a >> 4; + b = a >> 15; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudm $v03, $v02, $v31.e0 + vmudm $v03, $v02, $v31.e3 + vmudm $v03, $v02, $v30.e6 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Shift Right Logical (vec16)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v02> a, b; + b = a >>> 1; + b = a >>> 4; + b = a >>> 15; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudl $v03, $v02, $v31.e0 + vmudl $v03, $v02, $v31.e3 + vmudl $v03, $v02, $v30.e6 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Shift Left (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v02> a, b; + b = a << 1; + b = a << 4; + b = a << 15; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudl $v04, $v03, $v30.e6 + vmadn $v04, $v02, $v30.e6 + vmudn $v05, $v03, $v30.e6 + vmudl $v04, $v03, $v30.e3 + vmadn $v04, $v02, $v30.e3 + vmudn $v05, $v03, $v30.e3 + vmudl $v04, $v03, $v31.e0 + vmadn $v04, $v02, $v31.e0 + vmudn $v05, $v03, $v31.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Shift Left (vec32 self-assign)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v02> a, b; + a = a << 1; + a = a << 4; + a = a << 15; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudl $v29, $v03, $v30.e6 + vmadn $v02, $v02, $v30.e6 + vmudn $v03, $v03, $v30.e6 + vmudl $v29, $v03, $v30.e3 + vmadn $v02, $v02, $v30.e3 + vmudn $v03, $v03, $v30.e3 + vmudl $v29, $v03, $v31.e0 + vmadn $v02, $v02, $v31.e0 + vmudn $v03, $v03, $v31.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Shift Left (vec16 = vec32 << X)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v02> b; + vec16<$v04> a; + a = b << 1; + a = b << 4; + a = b << 15; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudl $v04, $v03, $v30.e6 + vmadn $v04, $v02, $v30.e6 + vmudl $v04, $v03, $v30.e3 + vmadn $v04, $v02, $v30.e3 + vmudl $v04, $v03, $v31.e0 + vmadn $v04, $v02, $v31.e0 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Shift right Arithmetic (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v02> a, b; + b = a >> 1; + b = a >> 4; + b = a >> 15; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudl $v05, $v03, $v31.e0 + vmadm $v04, $v02, $v31.e0 + vmadn $v05, $v00, $v00 + vmudl $v05, $v03, $v31.e3 + vmadm $v04, $v02, $v31.e3 + vmadn $v05, $v00, $v00 + vmudl $v05, $v03, $v30.e6 + vmadm $v04, $v02, $v30.e6 + vmadn $v05, $v00, $v00 + jr $ra + nop)"); +} + +TEST_CASE("Vector - Ops - Shift right Logical (vec32)", "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v02> a, b; + b = a >>> 1; + b = a >>> 4; + b = a >>> 15; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudl $v05, $v03, $v31.e0 + vmadn $v04, $v02, $v31.e0 + vmadn $v05, $v00, $v00 + vmudl $v05, $v03, $v31.e3 + vmadn $v04, $v02, $v31.e3 + vmadn $v05, $v00, $v00 + vmudl $v05, $v03, $v30.e6 + vmadn $v04, $v02, $v30.e6 + vmadn $v05, $v00, $v00 + jr $ra + nop)"); +} + +TEST_CASE("VectorOps - Multiply-accumulate +* - vec32", + "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> a; + vec32<$v03> b; + vec32<$v05> res; + res = a +* b; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmadl $v29, $v02, $v04.v + vmadm $v29, $v01, $v04.v + vmadn $v06, $v02, $v03.v + vmadh $v05, $v01, $v03.v + jr $ra + nop)"); +} + +TEST_CASE("VectorOps - Multiply vec16 * vec32 -> vec32", + "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> a; + vec32<$v03> b; + vec32<$v05> res; + res = a * b; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + vmudm $v06, $v01, $v04.v + vmadh $v05, $v01, $v03.v + vmadn $v06, $v00, $v00 + jr $ra + nop)"); +} + +TEST_CASE("VectorOps - Half-move vec32 xyzw=XYZW (upper to lower)", + "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + res.xyzw = a.XYZW; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + ori $at, $zero, %lo(RSPQ_SCRATCH_MEM) + sdv $v03, 8, 0, $at + sdv $v04, 8, 8, $at + ldv $v01, 0, 0, $at + ldv $v02, 0, 8, $at + jr $ra + nop)"); +} + +TEST_CASE("VectorOps - Half-move vec32 XYZW=xyzw (lower to upper)", + "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec32<$v01> res, a; + res.XYZW = a.xyzw; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + ori $at, $zero, %lo(RSPQ_SCRATCH_MEM) + sdv $v03, 0, 0, $at + sdv $v04, 0, 8, $at + ldv $v01, 8, 0, $at + ldv $v02, 8, 8, $at + jr $ra + nop)"); +} + +TEST_CASE("VectorOps - Half-move vec16 xyzw=XYZW (upper to lower)", + "[vectorOps]") { + auto result = rspl::transpileSource( + R"(function test() { + vec16<$v01> res, a; + res.xyzw = a.XYZW; + })", + {.rspqWrapper = false}); + + REQUIRE(result.warn.empty()); + REQUIRE(result.asm_ == R"(test: + ori $at, $zero, %lo(RSPQ_SCRATCH_MEM) + sdv $v02, 8, 0, $at + ldv $v01, 0, 0, $at + jr $ra + nop)"); +} diff --git a/scripts/parse.js b/scripts/parse.js new file mode 100644 index 0000000..c80752c --- /dev/null +++ b/scripts/parse.js @@ -0,0 +1,92 @@ +/** + * RSPL parser wrapper — reads .rspl source, outputs AST as JSON to stdout. + * + * Usage: + * node scripts/parse.js # full parse (preproc + parse) + * node scripts/parse.js --preprocessed # skip preprocessor + * + * This is the thin JS-side of the C++ port. It keeps the existing Nearley + * grammar and serializes the resulting AST as JSON. + */ + +import { readFileSync } from "fs"; +import * as path from "path"; +import nearly from "nearley"; +import grammarDef from "../src/lib/grammar.cjs"; +import { stripComments } from "../src/lib/preproc/preprocess.js"; +import { preprocess } from "../src/lib/preproc/preprocess.js"; + +const grammar = nearly.Grammar.fromCompiled(grammarDef); + +function fileLoader(filePath) { + const sourceBaseDir = + process.env.RSPL_SOURCE_DIR || path.dirname(args.inputFile || "."); + return readFileSync(path.join(sourceBaseDir, filePath), "utf8"); +} + +const args = { + inputFile: null, + skipPreproc: false, +}; + +for (let i = 2; i < process.argv.length; ++i) { + if (process.argv[i] === "--preprocessed") { + args.skipPreproc = true; + } else if (process.argv[i].startsWith("-")) { + // unknown flag — ignore + } else { + args.inputFile = process.argv[i]; + } +} + +function parse(source) { + const defines = {}; + if (args.skipPreproc) { + // Source already has comments stripped and defines expanded by C++ + source = stripComments(source); // strip comments again for safety + } else { + source = stripComments(source); + + if (process.env.RSPL_DEFINES) { + for (const def of process.env.RSPL_DEFINES.split(",")) { + const [key, value] = def.split("="); + if (key) { + defines[key] = { + regex: new RegExp(`\\b${key}\\b`, "g"), + value: value || "1", + }; + } + } + } + source = preprocess(source, defines, fileLoader); + } + + const parser = new nearly.Parser(grammar); + const astList = parser.feed(source); + + if (astList.results.length > 1) { + throw new Error("Warning: ambiguous syntax!"); + } + + const ast = astList.results[0]; + if (process.env.RSPL_KEEP_DEFINES && !args.skipPreproc) { + ast.defines = defines; + } else { + delete ast.defines; + } + + return ast; +} + +function main() { + if (!args.inputFile) { + console.error("Usage: node parse.js [--preprocessed] "); + process.exit(1); + } + + const source = readFileSync(args.inputFile, "utf8"); + const ast = parse(source); + process.stdout.write(JSON.stringify(ast, null, 2)); +} + +main();