diff --git a/lua/wildfire/init.lua b/lua/wildfire/init.lua index ec0628f..ac76da2 100644 --- a/lua/wildfire/init.lua +++ b/lua/wildfire/init.lua @@ -1,19 +1,12 @@ local api = vim.api local keymap = vim.keymap -local ts_utils = require("nvim-treesitter.ts_utils") -local parsers = require("nvim-treesitter.parsers") local utils = require("wildfire.utils") +local surround = require("wildfire.surround") local M = {} M.options = { - surrounds = { - { "(", ")" }, - { "{", "}" }, - { "<", ">" }, - { "[", "]" }, - }, keymaps = { init_selection = "", node_incremental = "", @@ -27,53 +20,20 @@ local selections = {} local nodes_all = {} local count = 1 -function M.unsurround_coordinates(node_or_range, buf) - -- local lines = vim.split(s, "\n") - local srow, scol, erow, ecol = utils.get_range(node_or_range) - local lines = vim.api.nvim_buf_get_text(buf, srow - 1, scol - 1, erow - 1, ecol, {}) - local node_text = table.concat(lines, "\n") - local match_brackets = nil - for _, pair in ipairs(M.options.surrounds) do - local pattern = "^%" .. pair[1] .. ".*%" .. pair[2] .. "$" - match_brackets = string.match(node_text, pattern) - if match_brackets then - break +---Get inner coordinates for a surround node +---@param node TSNode +---@param buf integer? buffer number +---@return boolean is_surround +---@return table {srow, scol, erow, ecol} 1-based vim coordinates +function M.unsurround_coordinates(node, buf) + if surround.is_surround(node) then + local inner = surround.get_inner_range(node, buf) + if inner then + return true, inner end + -- Empty content (e.g., () or {}), use node range end - -- local match_brackets = string.match(node_text, "^%b{}$") - -- or string.match(node_text, "^%b()$") - -- or string.match(node_text, "^%b[]$") - if match_brackets == nil then - return false, { srow, scol, erow, ecol } - end - lines[1] = lines[1]:sub(2) - local nsrow, nscol = 0, 0 - for index, line in ipairs(lines) do - if line:match("%S") then - nsrow = index - nscol = line:len() - line:match("^%s*(.*)"):len() - break - end - end - - lines[#lines] = lines[#lines]:sub(1, -2) - local nerow, necol = #lines, 0 - for index = #lines, 1, -1 do - local line = lines[index] - if line:match("%S") then - nerow = index - necol = line:len() - line:match("^(.*%S)%s*$"):len() - break - end - end - - nsrow = srow + nsrow - 1 - nscol = nsrow == srow and scol + nscol + 1 or nscol + 1 - -- nerow = erow - nerow + 1 - nerow = srow + nerow - 1 - necol = nerow == erow and ecol - necol - 1 or lines[nerow - srow + 1]:len() - necol - - return true, { nsrow, nscol, nerow, necol } + return false, { utils.get_range(node) } end local function update_selection_by_node(node) local buf = api.nvim_get_current_buf() @@ -102,8 +62,17 @@ local function init_by_node(node) end function M.init_selection() count = vim.v.count1 - local node = ts_utils.get_node_at_cursor() + local node = vim.treesitter.get_node({ ignore_injections = false }) if not node then + -- No treesitter node available, try to handle gracefully + -- Check if treesitter is available for this filetype + local buf = api.nvim_get_current_buf() + local ok, parser = pcall(vim.treesitter.get_parser, buf) + if not ok or not parser then + -- No parser available for this filetype + vim.notify("Wildfire: No treesitter parser available for this filetype", vim.log.levels.WARN) + return + end return end init_by_node(node) @@ -134,9 +103,21 @@ local function select_incremental(get_parent) -- Initialize incremental selection with current selection if not nodes or #nodes == 0 then - local root = parsers.get_parser():parse()[1]:root() + -- Use native vim.treesitter API to get the parser + local ok, parser = pcall(vim.treesitter.get_parser, buf) + if not ok or not parser then + -- No parser available for this filetype, fallback to visual selection + return + end + local tree = parser:parse()[1] + if not tree then + return + end + local root = tree:root() local node = root:named_descendant_for_range(csrow - 1, cscol - 1, cerow - 1, cecol) - update_selection_by_node(node) + if node then + update_selection_by_node(node) + end return end @@ -145,17 +126,49 @@ local function select_incremental(get_parent) while true do local parent = get_parent(node) if not parent or parent == node then - -- Keep searching in the main tree - -- TODO: we should search on the parent tree of the current node. - local root = parsers.get_parser():parse()[1]:root() - parent = root:named_descendant_for_range(csrow - 1, cscol - 1, cerow - 1, cecol) - if not parent or root == node or parent == node then + -- Search in parent language tree for injected languages + local ok, parser = pcall(vim.treesitter.get_parser, buf) + if not ok or not parser then + utils.update_selection(buf, node) + return + end + + local range = { csrow - 1, cscol - 1, cerow - 1, cecol } + local current_lang_tree = parser:language_for_range(range) + local parent_lang_tree = current_lang_tree and current_lang_tree:parent() + + if parent_lang_tree then + local parent_tree = parent_lang_tree:tree_for_range(range) + if parent_tree then + local root = parent_tree:root() + local candidate = root:named_descendant_for_range(csrow - 1, cscol - 1, cerow - 1, cecol) + if candidate then + local csr, csc, cer, cec = utils.get_range(candidate) + local candidate_range = { csr, csc, cer, cec } + local current_range = { csrow, cscol, cerow, cecol } + + if utils.range_larger(candidate_range, current_range) + and not utils.range_match(candidate_range, current_range) then + parent = candidate + else + utils.update_selection(buf, node) + return + end + else + utils.update_selection(buf, node) + return + end + else + utils.update_selection(buf, node) + return + end + else utils.update_selection(buf, node) return end end node = parent - local nsrow, nscol, nerow, necol = ts_utils.get_vim_range({ node:range() }) + local nsrow, nscol, nerow, necol = utils.get_range(node) local larger_range = utils.range_larger({ nsrow, nscol, nerow, necol }, { csrow, cscol, cerow, cecol }) @@ -191,8 +204,28 @@ end function M.visual_inner() local buf = api.nvim_get_current_buf() local csrow, cscol, cerow, cecol = utils.visual_selection_range() - local _, selection = M.unsurround_coordinates({ csrow, cscol, cerow, cecol }, buf) - utils.update_selection(buf, selection) + + -- Get the treesitter node at the selection range + local ok, parser = pcall(vim.treesitter.get_parser, buf) + if not ok or not parser then + return + end + + local tree = parser:parse()[1] + if not tree then + return + end + + local root = tree:root() + local node = root:named_descendant_for_range(csrow - 1, cscol - 1, cerow - 1, cecol) + if not node then + return + end + + local is_surround, selection = M.unsurround_coordinates(node, buf) + if is_surround then + utils.update_selection(buf, selection) + end end local FUNCTION_DESCRIPTIONS = { diff --git a/lua/wildfire/surround.lua b/lua/wildfire/surround.lua new file mode 100644 index 0000000..6d6bf99 --- /dev/null +++ b/lua/wildfire/surround.lua @@ -0,0 +1,77 @@ +local ts = vim.treesitter + +local M = {} + +---Check if node is a surround type by examining its children +---A surround node has unnamed (anonymous) nodes as first and last children +---These unnamed nodes are typically delimiters like (, ), {, }, ", ', etc. +---@param node TSNode +---@return boolean +function M.is_surround(node) + local child_count = node:child_count() + if child_count < 2 then + return false + end + + local first = node:child(0) + local last = node:child(child_count - 1) + + -- Anonymous nodes are delimiters + return first and last and not first:named() and not last:named() +end + +---Get inner range (excluding delimiter nodes) +---@param node TSNode +---@param buf? integer buffer number (needed for edge case handling) +---@return table|nil {srow, scol, erow, ecol} 1-based vim coordinates +function M.get_inner_range(node, buf) + local child_count = node:child_count() + if child_count < 2 then + return nil + end + + local first = node:child(0) + local last = node:child(child_count - 1) + + if not first or not last or first:named() or last:named() then + return nil + end + + -- Get the end of first delimiter and start of last delimiter + local first_end_row, first_end_col = first:end_() + local last_start_row, last_start_col = last:start() + + -- Convert to 1-based vim coordinates + local srow = first_end_row + 1 + local scol = first_end_col + 1 + local erow = last_start_row + 1 + local ecol = last_start_col -- 0-based exclusive == 1-based inclusive + + buf = buf or vim.api.nvim_get_current_buf() + + -- Handle edge case: when first delimiter ends at line end, + -- content starts at next line's beginning + local first_line = vim.api.nvim_buf_get_lines(buf, srow - 1, srow, false)[1] + if first_line and scol > #first_line then + srow = srow + 1 + scol = 1 + end + + -- Handle edge case: when last delimiter is at line start (col 0), + -- content ends at previous line's end + if ecol < 1 and erow > srow then + erow = erow - 1 + local line = vim.api.nvim_buf_get_lines(buf, erow - 1, erow, false)[1] + ecol = line and #line or 1 + end + + -- Handle empty content: when delimiters are adjacent (e.g., () or {}) + -- Return nil to indicate no inner content + if srow > erow or (srow == erow and scol > ecol) then + return nil + end + + return { srow, scol, erow, ecol } +end + +return M diff --git a/lua/wildfire/utils.lua b/lua/wildfire/utils.lua index 5915356..f8c8a23 100644 --- a/lua/wildfire/utils.lua +++ b/lua/wildfire/utils.lua @@ -1,6 +1,5 @@ local api = vim.api -local ts_utils = require("nvim-treesitter.ts_utils") local ts = vim.treesitter local M = {} @@ -9,8 +8,14 @@ function M.get_range(node_or_range) if type(node_or_range) == "table" then start_row, start_col, end_row, end_col = unpack(node_or_range) else - local buf = api.nvim_get_current_buf() - start_row, start_col, end_row, end_col = ts_utils.get_vim_range({ ts.get_node_range(node_or_range) }, buf) + start_row, start_col, end_row, end_col = ts.get_node_range(node_or_range) + -- Convert 0-based to 1-based indexing to match vim coordinates + -- Note: treesitter end_col is exclusive, but vim coordinates are inclusive + start_row = start_row + 1 + start_col = start_col + 1 + end_row = end_row + 1 + -- end_col is already exclusive in treesitter (0-based), converting to 1-based inclusive means no change needed + -- because: 0-based exclusive position == 1-based inclusive position end return start_row, start_col, end_row, end_col ---@type integer, integer, integer, integer end @@ -63,15 +68,15 @@ end function M.print_selection(node_or_range) local bufnr = api.nvim_get_current_buf() - local lines + local node_text if type(node_or_range) == "table" then local srow, scol, erow, ecol srow, scol, erow, ecol = unpack(node_or_range) - lines = vim.api.nvim_buf_get_text(bufnr, srow - 1, scol - 1, erow - 1, ecol, {}) + local lines = vim.api.nvim_buf_get_text(bufnr, srow - 1, scol - 1, erow - 1, ecol, {}) + node_text = table.concat(lines, "\n") else - lines = ts_utils.get_node_text(node_or_range, bufnr) + node_text = vim.treesitter.get_node_text(node_or_range, bufnr) end - local node_text = table.concat(lines, "\n") print(node_text) end @@ -81,14 +86,44 @@ function M.update_selection(buf, node_or_range, selection_mode) if type(node_or_range) == "table" then start_row, start_col, end_row, end_col = unpack(node_or_range) else - start_row, start_col, end_row, end_col = ts_utils.get_vim_range({ ts.get_node_range(node_or_range) }, buf) + start_row, start_col, end_row, end_col = ts.get_node_range(node_or_range) + -- Convert 0-based to 1-based indexing to match vim coordinates + -- Note: treesitter end_col is exclusive, but vim coordinates are inclusive + start_row = start_row + 1 + start_col = start_col + 1 + end_row = end_row + 1 + -- end_col is already exclusive in treesitter (0-based), converting to 1-based inclusive means no change needed + -- because: 0-based exclusive position == 1-based inclusive position end + -- Validate buffer bounds to prevent cursor position errors + local line_count = api.nvim_buf_line_count(buf) + if start_row < 1 or end_row < 1 or start_row > line_count or end_row > line_count then + -- Invalid row range, cannot update selection + return + end + + -- Validate column bounds + if start_col < 1 or end_col < 0 then + -- Invalid column range, cannot update selection + return + end + + -- Get the actual line lengths to validate column positions + local start_line_text = api.nvim_buf_get_lines(buf, start_row - 1, start_row, false)[1] or "" + local end_line_text = api.nvim_buf_get_lines(buf, end_row - 1, end_row, false)[1] or "" + local start_line_len = #start_line_text + local end_line_len = #end_line_text + + -- Clamp column positions to valid ranges (0-indexed for nvim_win_set_cursor) + start_col = math.max(0, math.min(start_col - 1, start_line_len)) + end_col = math.max(0, math.min(end_col - 1, end_line_len)) + local v_table = { charwise = "v", linewise = "V", blockwise = "" } selection_mode = selection_mode or "charwise" -- Normalise selection_mode - if vim.tbl_contains(vim.tbl_keys(v_table), selection_mode) then + if v_table[selection_mode] then selection_mode = v_table[selection_mode] end @@ -103,8 +138,8 @@ function M.update_selection(buf, node_or_range, selection_mode) api.nvim_cmd({ cmd = "normal", bang = true, args = { selection_mode } }, {}) end - api.nvim_win_set_cursor(0, { start_row, start_col - 1 }) + api.nvim_win_set_cursor(0, { start_row, start_col }) vim.cmd("normal! o") - api.nvim_win_set_cursor(0, { end_row, end_col - 1 }) + api.nvim_win_set_cursor(0, { end_row, end_col }) end return M