Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 97 additions & 64 deletions lua/wildfire/init.lua
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
local api = vim.api

local keymap = vim.keymap
local ts_utils = require("nvim-treesitter.ts_utils")
local parsers = require("nvim-treesitter.parsers")
local utils = require("wildfire.utils")
local surround = require("wildfire.surround")

local M = {}

M.options = {
surrounds = {
{ "(", ")" },
{ "{", "}" },
{ "<", ">" },
{ "[", "]" },
},
keymaps = {
init_selection = "<CR>",
node_incremental = "<CR>",
Expand All @@ -27,53 +20,20 @@ local selections = {}
local nodes_all = {}
local count = 1

function M.unsurround_coordinates(node_or_range, buf)
-- local lines = vim.split(s, "\n")
local srow, scol, erow, ecol = utils.get_range(node_or_range)
local lines = vim.api.nvim_buf_get_text(buf, srow - 1, scol - 1, erow - 1, ecol, {})
local node_text = table.concat(lines, "\n")
local match_brackets = nil
for _, pair in ipairs(M.options.surrounds) do
local pattern = "^%" .. pair[1] .. ".*%" .. pair[2] .. "$"
match_brackets = string.match(node_text, pattern)
if match_brackets then
break
---Get inner coordinates for a surround node
---@param node TSNode
---@param buf integer? buffer number
---@return boolean is_surround
---@return table {srow, scol, erow, ecol} 1-based vim coordinates
function M.unsurround_coordinates(node, buf)
if surround.is_surround(node) then
local inner = surround.get_inner_range(node, buf)
if inner then
return true, inner
end
-- Empty content (e.g., () or {}), use node range
end
-- local match_brackets = string.match(node_text, "^%b{}$")
-- or string.match(node_text, "^%b()$")
-- or string.match(node_text, "^%b[]$")
if match_brackets == nil then
return false, { srow, scol, erow, ecol }
end
lines[1] = lines[1]:sub(2)
local nsrow, nscol = 0, 0
for index, line in ipairs(lines) do
if line:match("%S") then
nsrow = index
nscol = line:len() - line:match("^%s*(.*)"):len()
break
end
end

lines[#lines] = lines[#lines]:sub(1, -2)
local nerow, necol = #lines, 0
for index = #lines, 1, -1 do
local line = lines[index]
if line:match("%S") then
nerow = index
necol = line:len() - line:match("^(.*%S)%s*$"):len()
break
end
end

nsrow = srow + nsrow - 1
nscol = nsrow == srow and scol + nscol + 1 or nscol + 1
-- nerow = erow - nerow + 1
nerow = srow + nerow - 1
necol = nerow == erow and ecol - necol - 1 or lines[nerow - srow + 1]:len() - necol

return true, { nsrow, nscol, nerow, necol }
return false, { utils.get_range(node) }
end
local function update_selection_by_node(node)
local buf = api.nvim_get_current_buf()
Expand Down Expand Up @@ -102,8 +62,17 @@ local function init_by_node(node)
end
function M.init_selection()
count = vim.v.count1
local node = ts_utils.get_node_at_cursor()
local node = vim.treesitter.get_node({ ignore_injections = false })
if not node then
-- No treesitter node available, try to handle gracefully
-- Check if treesitter is available for this filetype
local buf = api.nvim_get_current_buf()
local ok, parser = pcall(vim.treesitter.get_parser, buf)
if not ok or not parser then
-- No parser available for this filetype
vim.notify("Wildfire: No treesitter parser available for this filetype", vim.log.levels.WARN)
return
end
return
end
init_by_node(node)
Expand Down Expand Up @@ -134,9 +103,21 @@ local function select_incremental(get_parent)

-- Initialize incremental selection with current selection
if not nodes or #nodes == 0 then
local root = parsers.get_parser():parse()[1]:root()
-- Use native vim.treesitter API to get the parser
local ok, parser = pcall(vim.treesitter.get_parser, buf)
if not ok or not parser then
-- No parser available for this filetype, fallback to visual selection
return
end
local tree = parser:parse()[1]
if not tree then
return
end
local root = tree:root()
local node = root:named_descendant_for_range(csrow - 1, cscol - 1, cerow - 1, cecol)
update_selection_by_node(node)
if node then
update_selection_by_node(node)
end
return
end

Expand All @@ -145,17 +126,49 @@ local function select_incremental(get_parent)
while true do
local parent = get_parent(node)
if not parent or parent == node then
-- Keep searching in the main tree
-- TODO: we should search on the parent tree of the current node.
local root = parsers.get_parser():parse()[1]:root()
parent = root:named_descendant_for_range(csrow - 1, cscol - 1, cerow - 1, cecol)
if not parent or root == node or parent == node then
-- Search in parent language tree for injected languages
local ok, parser = pcall(vim.treesitter.get_parser, buf)
if not ok or not parser then
utils.update_selection(buf, node)
return
end

local range = { csrow - 1, cscol - 1, cerow - 1, cecol }
local current_lang_tree = parser:language_for_range(range)
local parent_lang_tree = current_lang_tree and current_lang_tree:parent()

if parent_lang_tree then
local parent_tree = parent_lang_tree:tree_for_range(range)
if parent_tree then
local root = parent_tree:root()
local candidate = root:named_descendant_for_range(csrow - 1, cscol - 1, cerow - 1, cecol)
if candidate then
local csr, csc, cer, cec = utils.get_range(candidate)
local candidate_range = { csr, csc, cer, cec }
local current_range = { csrow, cscol, cerow, cecol }

if utils.range_larger(candidate_range, current_range)
and not utils.range_match(candidate_range, current_range) then
parent = candidate
else
utils.update_selection(buf, node)
return
end
else
utils.update_selection(buf, node)
return
end
else
utils.update_selection(buf, node)
return
end
else
utils.update_selection(buf, node)
return
end
end
node = parent
local nsrow, nscol, nerow, necol = ts_utils.get_vim_range({ node:range() })
local nsrow, nscol, nerow, necol = utils.get_range(node)

local larger_range = utils.range_larger({ nsrow, nscol, nerow, necol }, { csrow, cscol, cerow, cecol })

Expand Down Expand Up @@ -191,8 +204,28 @@ end
function M.visual_inner()
local buf = api.nvim_get_current_buf()
local csrow, cscol, cerow, cecol = utils.visual_selection_range()
local _, selection = M.unsurround_coordinates({ csrow, cscol, cerow, cecol }, buf)
utils.update_selection(buf, selection)

-- Get the treesitter node at the selection range
local ok, parser = pcall(vim.treesitter.get_parser, buf)
if not ok or not parser then
return
end

local tree = parser:parse()[1]
if not tree then
return
end

local root = tree:root()
local node = root:named_descendant_for_range(csrow - 1, cscol - 1, cerow - 1, cecol)
if not node then
return
end

local is_surround, selection = M.unsurround_coordinates(node, buf)
if is_surround then
utils.update_selection(buf, selection)
end
end

local FUNCTION_DESCRIPTIONS = {
Expand Down
77 changes: 77 additions & 0 deletions lua/wildfire/surround.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
local ts = vim.treesitter

local M = {}

---Check if node is a surround type by examining its children
---A surround node has unnamed (anonymous) nodes as first and last children
---These unnamed nodes are typically delimiters like (, ), {, }, ", ', etc.
---@param node TSNode
---@return boolean
function M.is_surround(node)
local child_count = node:child_count()
if child_count < 2 then
return false
end

local first = node:child(0)
local last = node:child(child_count - 1)

-- Anonymous nodes are delimiters
return first and last and not first:named() and not last:named()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Limit surround heuristic to bracket delimiters

This heuristic classifies any node with anonymous first/last children as a surround, which incorrectly matches many non-delimiter constructs (for example statement nodes like return x; in JS grammars). In those cases unsurround_coordinates treats keywords/punctuation as wrappers and can introduce unintended intermediate selections, changing expansion behavior beyond bracket-like surrounds that the plugin previously targeted.

Useful? React with 👍 / 👎.

end

---Get inner range (excluding delimiter nodes)
---@param node TSNode
---@param buf? integer buffer number (needed for edge case handling)
---@return table|nil {srow, scol, erow, ecol} 1-based vim coordinates
function M.get_inner_range(node, buf)
local child_count = node:child_count()
if child_count < 2 then
return nil
end

local first = node:child(0)
local last = node:child(child_count - 1)

if not first or not last or first:named() or last:named() then
return nil
end

-- Get the end of first delimiter and start of last delimiter
local first_end_row, first_end_col = first:end_()
local last_start_row, last_start_col = last:start()

-- Convert to 1-based vim coordinates
local srow = first_end_row + 1
local scol = first_end_col + 1
local erow = last_start_row + 1
local ecol = last_start_col -- 0-based exclusive == 1-based inclusive

buf = buf or vim.api.nvim_get_current_buf()

-- Handle edge case: when first delimiter ends at line end,
-- content starts at next line's beginning
local first_line = vim.api.nvim_buf_get_lines(buf, srow - 1, srow, false)[1]
if first_line and scol > #first_line then
srow = srow + 1
scol = 1
end

-- Handle edge case: when last delimiter is at line start (col 0),
-- content ends at previous line's end
if ecol < 1 and erow > srow then
erow = erow - 1
local line = vim.api.nvim_buf_get_lines(buf, erow - 1, erow, false)[1]
ecol = line and #line or 1
end

-- Handle empty content: when delimiters are adjacent (e.g., () or {})
-- Return nil to indicate no inner content
if srow > erow or (srow == erow and scol > ecol) then
return nil
end

return { srow, scol, erow, ecol }
end

return M
57 changes: 46 additions & 11 deletions lua/wildfire/utils.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
local api = vim.api

local ts_utils = require("nvim-treesitter.ts_utils")
local ts = vim.treesitter

local M = {}
Expand All @@ -9,8 +8,14 @@ function M.get_range(node_or_range)
if type(node_or_range) == "table" then
start_row, start_col, end_row, end_col = unpack(node_or_range)
else
local buf = api.nvim_get_current_buf()
start_row, start_col, end_row, end_col = ts_utils.get_vim_range({ ts.get_node_range(node_or_range) }, buf)
start_row, start_col, end_row, end_col = ts.get_node_range(node_or_range)
-- Convert 0-based to 1-based indexing to match vim coordinates
-- Note: treesitter end_col is exclusive, but vim coordinates are inclusive
start_row = start_row + 1
start_col = start_col + 1
end_row = end_row + 1
-- end_col is already exclusive in treesitter (0-based), converting to 1-based inclusive means no change needed
-- because: 0-based exclusive position == 1-based inclusive position
end
return start_row, start_col, end_row, end_col ---@type integer, integer, integer, integer
end
Expand Down Expand Up @@ -63,15 +68,15 @@ end

function M.print_selection(node_or_range)
local bufnr = api.nvim_get_current_buf()
local lines
local node_text
if type(node_or_range) == "table" then
local srow, scol, erow, ecol
srow, scol, erow, ecol = unpack(node_or_range)
lines = vim.api.nvim_buf_get_text(bufnr, srow - 1, scol - 1, erow - 1, ecol, {})
local lines = vim.api.nvim_buf_get_text(bufnr, srow - 1, scol - 1, erow - 1, ecol, {})
node_text = table.concat(lines, "\n")
else
lines = ts_utils.get_node_text(node_or_range, bufnr)
node_text = vim.treesitter.get_node_text(node_or_range, bufnr)
end
local node_text = table.concat(lines, "\n")
print(node_text)
end

Expand All @@ -81,14 +86,44 @@ function M.update_selection(buf, node_or_range, selection_mode)
if type(node_or_range) == "table" then
start_row, start_col, end_row, end_col = unpack(node_or_range)
else
start_row, start_col, end_row, end_col = ts_utils.get_vim_range({ ts.get_node_range(node_or_range) }, buf)
start_row, start_col, end_row, end_col = ts.get_node_range(node_or_range)
-- Convert 0-based to 1-based indexing to match vim coordinates
-- Note: treesitter end_col is exclusive, but vim coordinates are inclusive
start_row = start_row + 1
start_col = start_col + 1
end_row = end_row + 1
-- end_col is already exclusive in treesitter (0-based), converting to 1-based inclusive means no change needed
-- because: 0-based exclusive position == 1-based inclusive position
end

-- Validate buffer bounds to prevent cursor position errors
local line_count = api.nvim_buf_line_count(buf)
if start_row < 1 or end_row < 1 or start_row > line_count or end_row > line_count then
-- Invalid row range, cannot update selection
Comment on lines +101 to +102
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Allow selecting nodes that end at EOF

This guard rejects valid Tree-sitter ranges at file end because ranges are end-exclusive before conversion. For nodes that end at EOF (common when the buffer has a trailing newline), end_row becomes line_count + 1 after +1, so update_selection returns early and incremental expansion silently stops before the outer node can be selected; meanwhile callers still push that node into history, which can desynchronize selection state from what is visible.

Useful? React with 👍 / 👎.

return
end

-- Validate column bounds
if start_col < 1 or end_col < 0 then
-- Invalid column range, cannot update selection
return
end

-- Get the actual line lengths to validate column positions
local start_line_text = api.nvim_buf_get_lines(buf, start_row - 1, start_row, false)[1] or ""
local end_line_text = api.nvim_buf_get_lines(buf, end_row - 1, end_row, false)[1] or ""
local start_line_len = #start_line_text
local end_line_len = #end_line_text

-- Clamp column positions to valid ranges (0-indexed for nvim_win_set_cursor)
start_col = math.max(0, math.min(start_col - 1, start_line_len))
end_col = math.max(0, math.min(end_col - 1, end_line_len))

local v_table = { charwise = "v", linewise = "V", blockwise = "<C-v>" }
selection_mode = selection_mode or "charwise"

-- Normalise selection_mode
if vim.tbl_contains(vim.tbl_keys(v_table), selection_mode) then
if v_table[selection_mode] then
selection_mode = v_table[selection_mode]
end

Expand All @@ -103,8 +138,8 @@ function M.update_selection(buf, node_or_range, selection_mode)
api.nvim_cmd({ cmd = "normal", bang = true, args = { selection_mode } }, {})
end

api.nvim_win_set_cursor(0, { start_row, start_col - 1 })
api.nvim_win_set_cursor(0, { start_row, start_col })
vim.cmd("normal! o")
api.nvim_win_set_cursor(0, { end_row, end_col - 1 })
api.nvim_win_set_cursor(0, { end_row, end_col })
end
return M
Loading