Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions knockout.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ for i = 3, #arg do
end

states.setMid((TARGET_MAX + TARGET_MIN) / 2)
states.setBounds(TARGET_MIN, TARGET_MAX)

-- initialize min heap
local heap = require("min")
Expand Down
59 changes: 51 additions & 8 deletions min.lua
Original file line number Diff line number Diff line change
@@ -1,16 +1,59 @@
local heap = {
data = {}
}
local heap = { data = {} }

local function priority(state)
return state.score + state.depth
end

local function siftUp(data, idx)
while idx > 1 do
local parent = math.floor(idx / 2)
if priority(data[parent]) <= priority(data[idx]) then
break
end
data[parent], data[idx] = data[idx], data[parent]
idx = parent
end
end

local function siftDown(data, idx)
local size = #data
while true do
local left = idx * 2
local right = left + 1
local smallest = idx

if left <= size and priority(data[left]) < priority(data[smallest]) then
smallest = left
end
if right <= size and priority(data[right]) < priority(data[smallest]) then
smallest = right
end
if smallest == idx then
break
end
data[idx], data[smallest] = data[smallest], data[idx]
idx = smallest
end
end

function heap.push(state)
heap.data[#heap.data + 1] = state
table.sort(heap.data, function(a, b)
return (a.score + a.depth) < (b.score + b.depth)
end)
local data = heap.data
data[#data + 1] = state
siftUp(data, #data)
end

function heap.pop()
return table.remove(heap.data, 1)
local data = heap.data
if #data == 0 then
return nil
end
local root = data[1]
local last = table.remove(data)
if #data > 0 then
data[1] = last
siftDown(data, 1)
end
return root
end

function heap.isEmpty()
Expand Down
79 changes: 60 additions & 19 deletions states.lua
Original file line number Diff line number Diff line change
@@ -1,26 +1,61 @@
local M = {}

M.mid = nil
M.minBound = nil
M.maxBound = nil

function M.setMid(mid)
M.mid = mid
end

function M.setBounds(minBound, maxBound)
M.minBound = minBound
M.maxBound = maxBound
end

local function isInteger(value)
return value ~= nil and math.tointeger(value) ~= nil
end

local function safePow(base, exp)
if not isInteger(exp) then
return nil
end
exp = math.tointeger(exp)
if exp < 0 then
return nil
end
if base == 0 and exp == 0 then
return nil
end
return base ^ exp
end

-- operations
local ops = {
["+"] = function(a,b) return a + b end,
["a-b"] = function(a,b) return a - b end,
["b-a"] = function(a,b) return b - a end,
["*"] = function(a,b) return a * b end,
["a/b"] = function(a,b) return (b ~= 0) and a / b or nil end,
["b/a"] = function(a,b) return (a ~= 0) and b / a or nil end
["+"] = function(a, b) return a + b end,
["a-b"] = function(a, b) return a - b end,
["b-a"] = function(a, b) return b - a end,
["*"] = function(a, b) return a * b end,
["a/b"] = function(a, b) return (b ~= 0 and a % b == 0) and a / b or nil end,
["b/a"] = function(a, b) return (a ~= 0 and b % a == 0) and b / a or nil end,
["a^b"] = function(a, b) return safePow(a, b) end,
["b^a"] = function(a, b) return safePow(b, a) end
}

-- given numbers a, b, return all binary ops
function M.calc(a, b)
local results = {}
local seen = {}
for op, func in pairs(ops) do
results[#results+1] = func(a, b)
local result = func(a, b)
if isInteger(result) then
result = math.tointeger(result)
if not seen[result] then
seen[result] = true
results[#results + 1] = result
end
end
end
return results
end
Expand Down Expand Up @@ -68,18 +103,16 @@ function M.searchNextDepth(state)
local a, b = nums[i], nums[j]

for _, r in ipairs(M.calc(a, b)) do
if math.tointeger(r) then
local nextNums = {}
local nextNums = {}

for k = 1, n do
if k ~= i and k ~= j then
nextNums[#nextNums + 1] = nums[k]
end
for k = 1, n do
if k ~= i and k ~= j then
nextNums[#nextNums + 1] = nums[k]
end

nextNums[#nextNums + 1] = r
results[#results + 1] = M.newState(nextNums, state)
end

nextNums[#nextNums + 1] = r
results[#results + 1] = M.newState(nextNums, state)
end
end
end
Expand All @@ -88,11 +121,19 @@ function M.searchNextDepth(state)
end

function M.canReach(state, min, max)
local sum = 0
if min == nil or max == nil then
return true
end
if min > max then
return false
end
local sumAbs = 0
for _, v in ipairs(state.raw) do
sum = sum + math.abs(v)
sumAbs = sumAbs + math.abs(v)
end
return sum >= min and sum <= max
local lower = -sumAbs
local upper = sumAbs
return not (max < lower or min > upper)
end

return M
Loading