diff --git a/Cargo.lock b/Cargo.lock index e47f216..0ae0a42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -316,14 +316,13 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.41" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7" +checksum = "066fce287b1d4eafef758e89e09d724a24808a9196fe9756b8ca90e86d0719a2" dependencies = [ - "find-msvc-tools", "jobserver", "libc", - "shlex", + "once_cell", ] [[package]] @@ -1116,12 +1115,6 @@ dependencies = [ "windows-sys 0.60.2", ] -[[package]] -name = "find-msvc-tools" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" - [[package]] name = "flate2" version = "1.1.5" @@ -1947,9 +1940,9 @@ dependencies = [ [[package]] name = "libsqlite3-sys" -version = "0.30.1" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +checksum = "6b694a822684ddb75df4d657029161431bcb4a85c1856952f845b76912bc6fec" dependencies = [ "cc", "pkg-config", @@ -3046,6 +3039,28 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "reposcout-ast" +version = "0.1.0" +dependencies = [ + "once_cell", + "regex", + "reposcout-core", + "serde", + "serde_json", + "thiserror 2.0.17", + "tokio", + "tracing", + "tree-sitter", + "tree-sitter-c", + "tree-sitter-cpp", + "tree-sitter-go", + "tree-sitter-javascript", + "tree-sitter-python", + "tree-sitter-rust", + "tree-sitter-typescript", +] + [[package]] name = "reposcout-cache" version = "0.1.0" @@ -3068,6 +3083,7 @@ dependencies = [ "clap", "dirs 5.0.1", "reposcout-api", + "reposcout-ast", "reposcout-cache", "reposcout-core", "reposcout-semantic", @@ -3122,6 +3138,7 @@ dependencies = [ "dirs 6.0.0", "fastembed", "regex", + "reposcout-ast", "reposcout-core", "rmp-serde", "serde", @@ 
-3147,6 +3164,7 @@ dependencies = [ "open", "ratatui", "reposcout-api", + "reposcout-ast", "reposcout-cache", "reposcout-core", "reposcout-deps", @@ -3212,9 +3230,9 @@ checksum = "0c6a884d2998352bb4daf0183589aec883f16a6da1f4dde84d8e2e9a5409a1ce" [[package]] name = "ring" -version = "0.17.14" +version = "0.17.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +checksum = "e75ec5e92c4d8aede845126adc388046234541629e76029599ed35a003c7ed24" dependencies = [ "cc", "cfg-if", @@ -3248,9 +3266,9 @@ dependencies = [ [[package]] name = "rusqlite" -version = "0.32.1" +version = "0.32.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +checksum = "1cdbe9230a57259b37f7257d0aff38b8c9dbda3513edba2105e59b130189d82f" dependencies = [ "bitflags", "fallible-iterator", @@ -3477,12 +3495,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "shlex" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" - [[package]] name = "signal-hook" version = "0.3.18" @@ -4132,6 +4144,86 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "tree-sitter" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df7cc499ceadd4dcdf7ec6d4cbc34ece92c3fa07821e287aedecd4416c516dca" +dependencies = [ + "cc", + "regex", +] + +[[package]] +name = "tree-sitter-c" +version = "0.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f956d5351d62652864a4ff3ae861747e7a1940dc96c9998ae400ac0d3ce30427" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-cpp" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1d509a22a992790d38f2c291961ff8a1ff016c437c7ec6befc9220b8eec8918c" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-go" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8d702a98d3c7e70e466456e58ff2b1ac550bf1e29b97e5770676d2fdabec00d" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-javascript" +version = "0.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8710a71bc6779e33811a8067bdda3ed08bed1733296ff915e44faf60f8c533d7" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-python" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4066c6cf678f962f8c2c4561f205945c84834cce73d981e71392624fdc390a9" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-rust" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "277690f420bf90741dea984f3da038ace46c4fe6047cba57a66822226cde1c93" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-typescript" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecb35d98a688378e56c18c9c159824fd16f730ccbea19aacf4f206e5d5438ed9" +dependencies = [ + "cc", + "tree-sitter", +] + [[package]] name = "try-lock" version = "0.2.5" diff --git a/Cargo.toml b/Cargo.toml index bca16c4..32e7b9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,9 @@ members = [ "crates/reposcout-tui", "crates/reposcout-api", "crates/reposcout-cache", - "crates/reposcout-deps", "crates/reposcout-semantic", + "crates/reposcout-deps", + "crates/reposcout-semantic", + "crates/reposcout-ast", ] resolver = "2" diff --git a/crates/reposcout-api/src/github.rs b/crates/reposcout-api/src/github.rs index 652e79c..501ad73 100644 --- a/crates/reposcout-api/src/github.rs +++ b/crates/reposcout-api/src/github.rs @@ 
-298,6 +298,12 @@ impl GitHubClient { "Response snippet: {}", &response_text[..response_text.len().min(1000)] ); + // Save full response to file for debugging + if let Err(write_err) = std::fs::write("/tmp/github_response_debug.json", &response_text) { + tracing::warn!("Failed to write debug response: {}", write_err); + } else { + tracing::error!("Full response saved to /tmp/github_response_debug.json"); + } GitHubError::ParseError(e) })?; Ok(search_result.items) @@ -513,8 +519,10 @@ pub struct CodeSearchItem { pub name: String, pub path: String, pub sha: String, - pub url: String, - pub git_url: String, + #[serde(default)] + pub url: Option, + #[serde(default)] + pub git_url: Option, pub html_url: String, pub repository: CodeSearchRepository, #[serde(default)] @@ -530,10 +538,20 @@ pub struct CodeSearchRepository { pub full_name: String, #[serde(default)] pub description: Option, - pub html_url: String, + #[serde(default)] + pub html_url: Option, pub owner: Owner, #[serde(default)] pub private: bool, + #[serde(default)] + pub stargazers_count: u32, + #[serde(default)] + pub language: Option, + // Additional fields that might be null or missing + #[serde(default)] + pub fork: Option, + #[serde(default)] + pub url: Option, } /// Text match containing the actual code snippet @@ -553,7 +571,8 @@ pub struct TextMatch { /// Individual match within a text fragment #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Match { - pub text: String, + #[serde(default)] + pub text: Option, pub indices: Vec, } @@ -588,7 +607,18 @@ pub struct GitHubRepo { pub struct Owner { pub login: String, pub id: u64, - pub avatar_url: String, + #[serde(default)] + pub avatar_url: Option, + #[serde(default)] + pub gravatar_id: Option, + #[serde(default)] + pub url: Option, + #[serde(default)] + pub html_url: Option, + // Flatten any additional unknown fields + #[serde(flatten)] + #[serde(default)] + pub additional_fields: std::collections::HashMap, } #[derive(Debug, Clone, Serialize, 
Deserialize)] diff --git a/crates/reposcout-ast/Cargo.toml b/crates/reposcout-ast/Cargo.toml new file mode 100644 index 0000000..3722f54 --- /dev/null +++ b/crates/reposcout-ast/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "reposcout-ast" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +homepage.workspace = true + +[dependencies] +# Internal dependencies +reposcout-core = { path = "../reposcout-core" } + +# Tree-sitter core +tree-sitter = "0.22" + +# Language grammars +tree-sitter-rust = "0.21" +tree-sitter-python = "0.21" +tree-sitter-javascript = "0.21" +tree-sitter-typescript = "0.21" +tree-sitter-go = "0.21" +tree-sitter-c = "0.21" +tree-sitter-cpp = "0.22" + +# Utilities +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "2.0" +tracing = "0.1" +regex = "1.10" +once_cell = "1.19" + +[dev-dependencies] +tokio = { version = "1", features = ["full"] } diff --git a/crates/reposcout-ast/src/error.rs b/crates/reposcout-ast/src/error.rs new file mode 100644 index 0000000..d2b7447 --- /dev/null +++ b/crates/reposcout-ast/src/error.rs @@ -0,0 +1,30 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum AstError { + #[error("Unsupported language: {0}")] + UnsupportedLanguage(String), + + #[error("Parse error: {0}")] + ParseError(String), + + #[error("Tree-sitter error: {0}")] + TreeSitterError(String), + + #[error("Extraction failed: {0}")] + ExtractionError(String), + + #[error("Query parsing error: {0}")] + QueryParseError(String), + + #[error("Timeout while parsing (exceeded {timeout_ms}ms)")] + ParseTimeout { timeout_ms: u64 }, + + #[error("File too large: {size} bytes (max: {max})")] + FileTooLarge { size: usize, max: usize }, + + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), +} + +pub type Result = std::result::Result; diff --git a/crates/reposcout-ast/src/extractors/go.rs b/crates/reposcout-ast/src/extractors/go.rs new 
file mode 100644 index 0000000..faa0723 --- /dev/null +++ b/crates/reposcout-ast/src/extractors/go.rs @@ -0,0 +1,415 @@ +use super::AstExtractor; +use crate::error::{AstError, Result}; +use crate::parser::ParserCache; +use reposcout_core::models::{ + AstMetadata, FunctionSignature, ImportStatement, Parameter, TypeDefinition, TypeKind, + Visibility, +}; +use tree_sitter::Node; + +pub struct GoExtractor; + +impl AstExtractor for GoExtractor { + fn language(&self) -> &'static str { + "go" + } + + fn extract_functions(&self, node: Node, source: &str) -> Vec { + let mut functions = Vec::new(); + self.extract_functions_recursive(node, source, &mut functions); + functions + } + + fn extract_types(&self, node: Node, source: &str) -> Vec { + let mut types = Vec::new(); + self.extract_types_recursive(node, source, &mut types); + types + } + + fn extract_imports(&self, node: Node, source: &str) -> Vec { + let mut imports = Vec::new(); + self.extract_imports_recursive(node, source, &mut imports); + imports + } + + fn extract_all(&self, code: &str) -> Result { + let cache = ParserCache::get(); + let tree = cache + .parse(code, self.language()) + .map_err(|e| AstError::ParseError(e.to_string()))?; + + let root_node = tree.root_node(); + + let functions = self.extract_functions(root_node, code); + let types = self.extract_types(root_node, code); + let imports = self.extract_imports(root_node, code); + + // Generate structure summary + let mut summary_parts = Vec::new(); + if !functions.is_empty() { + let fn_names: Vec<&str> = functions.iter().map(|f| f.name.as_str()).collect(); + summary_parts.push(format!("functions: {}", fn_names.join(", "))); + } + if !types.is_empty() { + let type_names: Vec<&str> = types.iter().map(|t| t.name.as_str()).collect(); + summary_parts.push(format!("types: {}", type_names.join(", "))); + } + + Ok(AstMetadata { + language: self.language().to_string(), + functions, + types, + imports, + structure_summary: summary_parts.join("; "), + parse_success: 
true,
+            parse_error: None,
+        })
+    }
+}
+
+impl GoExtractor {
+    /// Returns the source text covered by `node`, or `None` if that span is not
+    /// valid UTF-8.
+    fn node_text(node: Node, source: &str) -> Option<String> {
+        node.utf8_text(source.as_bytes()).ok().map(str::to_string)
+    }
+
+    /// Walks the tree under `node` and appends a signature for every function and
+    /// method declaration found. Declarations themselves are not descended into.
+    fn extract_functions_recursive(
+        &self,
+        node: Node,
+        source: &str,
+        functions: &mut Vec<FunctionSignature>,
+    ) {
+        let mut cursor = node.walk();
+
+        for child in node.children(&mut cursor) {
+            match child.kind() {
+                "function_declaration" | "method_declaration" => {
+                    if let Some(func) = self.parse_function(child, source) {
+                        functions.push(func);
+                    }
+                }
+                _ => {
+                    self.extract_functions_recursive(child, source, functions);
+                }
+            }
+        }
+    }
+
+    /// Builds a signature from a `function_declaration` or `method_declaration`.
+    ///
+    /// Children are looked up by grammar field rather than by node kind: a method's
+    /// name is a `field_identifier` (not an `identifier`), and the receiver, the
+    /// parameters and a parenthesised result are all `parameter_list` nodes, so
+    /// matching on kind drops methods and mistakes results for parameters.
+    /// The receiver is not reported as a parameter.
+    ///
+    /// Returns `None` when the declaration has no readable name.
+    fn parse_function(&self, node: Node, source: &str) -> Option<FunctionSignature> {
+        let name = node
+            .child_by_field_name("name")
+            .and_then(|n| Self::node_text(n, source))
+            .filter(|n| !n.is_empty())?;
+
+        let parameters = node
+            .child_by_field_name("parameters")
+            .map(|list| self.parse_parameters(list, source))
+            .unwrap_or_default();
+
+        // Every result form (`int`, `*T`, `map[K]V`, `(n int, err error)`, ...) is
+        // kept as written in the source.
+        let return_type = node
+            .child_by_field_name("result")
+            .and_then(|n| Self::node_text(n, source));
+
+        // In Go, visibility is based on capitalization
+        let visibility = if name.chars().next().is_some_and(char::is_uppercase) {
+            Visibility::Public
+        } else {
+            Visibility::Private
+        };
+
+        Some(FunctionSignature {
+            name,
+            parameters,
+            return_type,
+            visibility,
+            is_async: false, // Go doesn't have async/await
+            is_generic: node.child_by_field_name("type_parameters").is_some(),
+            doc_comment: None,
+            line_number: node.start_position().row + 1,
+        })
+    }
+
+    /// Collects the named parameters declared in a `parameter_list`.
+    fn parse_parameters(&self, node: Node, source: &str) -> Vec<Parameter> {
+        let mut parameters = Vec::new();
+        let mut cursor = node.walk();
+
+        for child in node.children(&mut cursor) {
+            if matches!(
+                child.kind(),
+                "parameter_declaration" | "variadic_parameter_declaration"
+            ) {
+                parameters.extend(self.parse_parameter(child, source));
+            }
+        }
+
+        parameters
+    }
+
+    /// Expands one parameter declaration into one `Parameter` per declared name, so
+    /// `a, b int` yields both `a` and `b`. A variadic parameter keeps its `...`
+    /// prefix in `param_type`. Unnamed parameters yield nothing.
+    fn parse_parameter(&self, node: Node, source: &str) -> Vec<Parameter> {
+        let type_text = node
+            .child_by_field_name("type")
+            .and_then(|n| Self::node_text(n, source));
+        let param_type = if node.kind() == "variadic_parameter_declaration" {
+            type_text.map(|t| format!("...{}", t))
+        } else {
+            type_text
+        };
+
+        let mut cursor = node.walk();
+        let parameters: Vec<Parameter> = node
+            .children_by_field_name("name", &mut cursor)
+            .filter_map(|n| Self::node_text(n, source))
+            .filter(|name| !name.is_empty())
+            .map(|name| Parameter {
+                name,
+                param_type: param_type.clone(),
+                is_optional: false,
+            })
+            .collect();
+        parameters
+    }
+
+    /// Walks the tree under `node` and appends every type introduced by a
+    /// `type_declaration`.
+    fn extract_types_recursive(
+        &self,
+        node: Node,
+        source: &str,
+        types: &mut Vec<TypeDefinition>,
+    ) {
+        let mut cursor = node.walk();
+
+        for child in node.children(&mut cursor) {
+            if child.kind() == "type_declaration" {
+                self.parse_type_declaration(child, source, types);
+            } else {
+                self.extract_types_recursive(child, source, types);
+            }
+        }
+    }
+
+    /// Appends one `TypeDefinition` per spec inside a `type_declaration`.
+    ///
+    /// A declaration may be grouped (`type ( A struct{}; B interface{} )`), so each
+    /// `type_spec` / `type_alias` child is handled on its own and reported at its
+    /// own line. The name is read from the `name` field because the underlying
+    /// type can itself be a `type_identifier` (`type MyInt int`) and must not
+    /// replace it. Anything that is neither a struct nor an interface is reported
+    /// as `TypeKind::Type`.
+    fn parse_type_declaration(
+        &self,
+        node: Node,
+        source: &str,
+        types: &mut Vec<TypeDefinition>,
+    ) {
+        let mut cursor = node.walk();
+
+        for spec in node.children(&mut cursor) {
+            if !matches!(spec.kind(), "type_spec" | "type_alias") {
+                continue;
+            }
+
+            let name = match spec
+                .child_by_field_name("name")
+                .and_then(|n| Self::node_text(n, source))
+                .filter(|n| !n.is_empty())
+            {
+                Some(name) => name,
+                None => continue,
+            };
+
+            let kind = match spec.child_by_field_name("type").map(|t| t.kind()) {
+                Some("struct_type") => TypeKind::Struct,
+                Some("interface_type") => TypeKind::Interface,
+                _ => TypeKind::Type,
+            };
+
+            types.push(TypeDefinition {
+                name,
+                kind,
+                fields: Vec::new(),
+                methods: Vec::new(),
+                implements: Vec::new(),
+                doc_comment: None,
+                line_number: spec.start_position().row + 1,
+            });
+        }
+    }
+
+    fn extract_imports_recursive(
+        &self,
+ node: Node, + source: &str, + imports: &mut Vec, + ) { + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + if child.kind() == "import_declaration" { + self.parse_import_declaration(child, source, imports); + } else { + self.extract_imports_recursive(child, source, imports); + } + } + } + + fn parse_import_declaration( + &self, + node: Node, + source: &str, + imports: &mut Vec, + ) { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "import_spec" | "import_spec_list" => { + if child.kind() == "import_spec" { + if let Some(import) = self.parse_import_spec(child, source) { + imports.push(import); + } + } else { + let mut list_cursor = child.walk(); + for list_child in child.children(&mut list_cursor) { + if list_child.kind() == "import_spec" { + if let Some(import) = self.parse_import_spec(list_child, source) { + imports.push(import); + } + } + } + } + } + _ => {} + } + } + } + + fn parse_import_spec(&self, node: Node, source: &str) -> Option { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.kind() == "interpreted_string_literal" { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + let module = text.trim_matches('"').to_string(); + return Some(ImportStatement { + module, + items: vec![], + is_wildcard: false, + }); + } + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_go_function() { + let code = r#" +package main + +func Add(a int, b int) int { + return a + b +} +"#; + + let extractor = GoExtractor; + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 1); + assert_eq!(ast.functions[0].name, "Add"); + assert_eq!(ast.functions[0].parameters.len(), 2); + assert_eq!(ast.functions[0].visibility, Visibility::Public); // Capitalized + } + + #[test] + fn test_extract_struct() { + let code = r#" +package main + +type User struct { + Name string + Age int +} +"#; + + let 
extractor = GoExtractor; + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.types.len(), 1); + assert_eq!(ast.types[0].name, "User"); + assert_eq!(ast.types[0].kind, TypeKind::Struct); + } + + #[test] + fn test_extract_interface() { + let code = r#" +package main + +type Reader interface { + Read(p []byte) (n int, err error) +} +"#; + + let extractor = GoExtractor; + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.types.len(), 1); + assert_eq!(ast.types[0].name, "Reader"); + assert_eq!(ast.types[0].kind, TypeKind::Interface); + } + + #[test] + fn test_extract_imports() { + let code = r#" +package main + +import ( + "fmt" + "net/http" +) +"#; + + let extractor = GoExtractor; + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.imports.len(), 2); + assert!(ast.imports.iter().any(|i| i.module == "fmt")); + assert!(ast.imports.iter().any(|i| i.module == "net/http")); + } + + #[test] + fn test_visibility_from_case() { + let code = r#" +package main + +func PublicFunc() {} +func privateFunc() {} +"#; + + let extractor = GoExtractor; + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 2); + assert_eq!(ast.functions[0].visibility, Visibility::Public); + assert_eq!(ast.functions[1].visibility, Visibility::Private); + } +} diff --git a/crates/reposcout-ast/src/extractors/javascript.rs b/crates/reposcout-ast/src/extractors/javascript.rs new file mode 100644 index 0000000..5aeb881 --- /dev/null +++ b/crates/reposcout-ast/src/extractors/javascript.rs @@ -0,0 +1,666 @@ +use super::AstExtractor; +use crate::error::{AstError, Result}; +use crate::parser::ParserCache; +use reposcout_core::models::{ + AstMetadata, FunctionSignature, ImportStatement, Parameter, TypeDefinition, TypeKind, + Visibility, +}; +use tree_sitter::Node; + +pub struct JavaScriptExtractor { + language: &'static str, +} + +impl JavaScriptExtractor { + pub fn new(language: &'static str) -> Self { + Self { language } + } +} + 
+impl AstExtractor for JavaScriptExtractor { + fn language(&self) -> &'static str { + self.language + } + + fn extract_functions(&self, node: Node, source: &str) -> Vec { + let mut functions = Vec::new(); + self.extract_functions_recursive(node, source, &mut functions); + functions + } + + fn extract_types(&self, node: Node, source: &str) -> Vec { + let mut types = Vec::new(); + self.extract_types_recursive(node, source, &mut types); + types + } + + fn extract_imports(&self, node: Node, source: &str) -> Vec { + let mut imports = Vec::new(); + self.extract_imports_recursive(node, source, &mut imports); + imports + } + + fn extract_all(&self, code: &str) -> Result { + let cache = ParserCache::get(); + let tree = cache + .parse(code, self.language()) + .map_err(|e| AstError::ParseError(e.to_string()))?; + + let root_node = tree.root_node(); + + let functions = self.extract_functions(root_node, code); + let types = self.extract_types(root_node, code); + let imports = self.extract_imports(root_node, code); + + // Generate structure summary + let mut summary_parts = Vec::new(); + if !functions.is_empty() { + let fn_names: Vec<&str> = functions.iter().map(|f| f.name.as_str()).collect(); + summary_parts.push(format!("functions: {}", fn_names.join(", "))); + } + if !types.is_empty() { + let type_names: Vec<&str> = types.iter().map(|t| t.name.as_str()).collect(); + summary_parts.push(format!("classes: {}", type_names.join(", "))); + } + + Ok(AstMetadata { + language: self.language().to_string(), + functions, + types, + imports, + structure_summary: summary_parts.join("; "), + parse_success: true, + parse_error: None, + }) + } +} + +impl JavaScriptExtractor { + fn extract_functions_recursive( + &self, + node: Node, + source: &str, + functions: &mut Vec, + ) { + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + match child.kind() { + "function_declaration" | "function" | "method_definition" | "function_signature" => { + if let Some(func) = 
self.parse_function(child, source) {
+                        functions.push(func);
+                    }
+                }
+                "lexical_declaration" | "variable_declaration" => {
+                    // Arrow functions and function expressions bound to variables
+                    self.parse_variable_function(child, source, functions);
+                }
+                _ => {
+                    self.extract_functions_recursive(child, source, functions);
+                }
+            }
+        }
+    }
+
+    /// Builds a signature from a function-like node (declaration, expression,
+    /// method definition or signature).
+    ///
+    /// The name is the first `identifier` / `property_identifier` child. An
+    /// anonymous function falls back to the binding that holds it: the key of an
+    /// enclosing object `pair`, or the identifier of an enclosing
+    /// `variable_declarator`. Returns `None` when no name can be determined.
+    fn parse_function(&self, node: Node, source: &str) -> Option<FunctionSignature> {
+        let mut name = String::new();
+        let mut parameters = Vec::new();
+        let mut return_type = None;
+        let mut is_async = false;
+        let line_number = node.start_position().row + 1;
+
+        let mut cursor = node.walk();
+        for child in node.children(&mut cursor) {
+            match child.kind() {
+                "async" => {
+                    is_async = true;
+                }
+                "identifier" | "property_identifier" if name.is_empty() => {
+                    if let Ok(text) = child.utf8_text(source.as_bytes()) {
+                        name = text.to_string();
+                    }
+                }
+                "formal_parameters" => {
+                    parameters = self.parse_parameters(child, source);
+                }
+                "type_annotation" => {
+                    if let Ok(text) = child.utf8_text(source.as_bytes()) {
+                        return_type = Some(text.trim_start_matches(':').trim().to_string());
+                    }
+                }
+                _ => {}
+            }
+        }
+
+        // Anonymous functions take their name from whatever binds them
+        if name.is_empty() {
+            if let Some(parent) = node.parent() {
+                let binding = match parent.kind() {
+                    // Object method: `{ key: function() {} }`
+                    "pair" => parent.child_by_field_name("key"),
+                    // `const foo = function() {}`
+                    "variable_declarator" => parent
+                        .child_by_field_name("name")
+                        .filter(|n| n.kind() == "identifier"),
+                    _ => None,
+                };
+                if let Some(text) = binding.and_then(|n| n.utf8_text(source.as_bytes()).ok()) {
+                    name = text.to_string();
+                }
+            }
+        }
+
+        if name.is_empty() {
+            return None;
+        }
+
+        Some(FunctionSignature {
+            name,
+            parameters,
+            return_type,
+            visibility: Visibility::Public, // JS doesn't have explicit visibility
+            is_async,
+            is_generic: false,
+            doc_comment: None,
+            line_number,
+        })
+    }
+
+    /// Records every function bound by a `lexical_declaration` or
+    /// `variable_declaration`, e.g. `const foo = async (params) => {}` or
+    /// `const foo = function(params) {}`.
+    ///
+    /// Each declarator is handled on its own, so `const a = () => 1, b = () => 2`
+    /// yields both functions, each named after its variable. An initializer that
+    /// is not a function is searched recursively, so functions nested inside it
+    /// (for example object-literal methods) are still found.
+    fn parse_variable_function(
+        &self,
+        node: Node,
+        source: &str,
+        functions: &mut Vec<FunctionSignature>,
+    ) {
+        let mut cursor = node.walk();
+
+        for declarator in node.children(&mut cursor) {
+            if declarator.kind() != "variable_declarator" {
+                continue;
+            }
+
+            // `let x;` has no initializer and therefore binds no function
+            let value = match declarator.child_by_field_name("value") {
+                Some(value) => value,
+                None => continue,
+            };
+
+            // Destructuring patterns have no single name to give a function
+            let name = declarator
+                .child_by_field_name("name")
+                .filter(|n| n.kind() == "identifier")
+                .and_then(|n| n.utf8_text(source.as_bytes()).ok());
+
+            let parsed = match (value.kind(), name) {
+                ("arrow_function", Some(name)) => {
+                    let mut arrow_cursor = value.walk();
+                    let is_async = value
+                        .children(&mut arrow_cursor)
+                        .any(|c| c.kind() == "async");
+                    self.parse_arrow_function(value, source, name, is_async)
+                }
+                ("function" | "function_expression", Some(name)) => {
+                    // The variable name wins over the expression's own name
+                    self.parse_function(value, source).map(|mut func| {
+                        func.name = name.to_string();
+                        func
+                    })
+                }
+                _ => {
+                    self.extract_functions_recursive(value, source, functions);
+                    None
+                }
+            };
+
+            functions.extend(parsed);
+        }
+    }
+
+    /// Builds the signature of an arrow function bound to `name`.
+    ///
+    /// Only the children before the `=>` token describe the signature. Scanning
+    /// stops there so that an expression body consisting of a bare identifier
+    /// (`x => y`) is not counted as another parameter.
+    fn parse_arrow_function(
+        &self,
+        node: Node,
+        source: &str,
+        name: &str,
+        is_async: bool,
+    ) -> Option<FunctionSignature> {
+        let mut parameters = Vec::new();
+        let mut return_type = None;
+        let line_number = node.start_position().row + 1;
+
+        let mut cursor = node.walk();
+        for child in node.children(&mut cursor) {
+            match child.kind() {
+                "=>" => break,
+                "identifier" => {
+                    // Single parameter without parentheses
+                    if let Ok(text) = child.utf8_text(source.as_bytes()) {
+                        parameters.push(Parameter {
+                            name: text.to_string(),
+                            param_type: None,
+                            is_optional: false,
+                        });
+                    }
+                }
+                "formal_parameters" => {
+                    parameters = self.parse_parameters(child, source);
+                }
+                "type_annotation" => {
+                    if let Ok(text) = child.utf8_text(source.as_bytes()) {
+                        return_type = Some(text.trim_start_matches(':').trim().to_string());
+                    }
+                }
+                _ => {}
+            }
+        }
+
+        Some(FunctionSignature {
+            name: name.to_string(),
+            parameters,
+            return_type,
visibility: Visibility::Public, + is_async, + is_generic: false, + doc_comment: None, + line_number, + }) + } + + fn parse_parameters(&self, node: Node, source: &str) -> Vec { + let mut parameters = Vec::new(); + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + match child.kind() { + "required_parameter" | "optional_parameter" | "identifier" => { + if let Some(param) = self.parse_parameter(child, source) { + parameters.push(param); + } + } + _ => {} + } + } + + parameters + } + + fn parse_parameter(&self, node: Node, source: &str) -> Option { + let mut name = String::new(); + let mut param_type = None; + let is_optional = node.kind() == "optional_parameter"; + + if node.kind() == "identifier" { + if let Ok(text) = node.utf8_text(source.as_bytes()) { + name = text.to_string(); + } + } else { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "identifier" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = text.to_string(); + } + } + "type_annotation" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + param_type = Some(text.trim_start_matches(':').trim().to_string()); + } + } + _ => {} + } + } + } + + if !name.is_empty() { + Some(Parameter { + name, + param_type, + is_optional, + }) + } else { + None + } + } + + fn extract_types_recursive( + &self, + node: Node, + source: &str, + types: &mut Vec, + ) { + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + match child.kind() { + "class_declaration" => { + if let Some(type_def) = self.parse_class(child, source) { + types.push(type_def); + } + } + "interface_declaration" => { + if let Some(type_def) = self.parse_interface(child, source) { + types.push(type_def); + } + } + "type_alias_declaration" => { + if let Some(type_def) = self.parse_type_alias(child, source) { + types.push(type_def); + } + } + _ => { + self.extract_types_recursive(child, source, types); + } + } + } + } + + fn 
parse_class(&self, node: Node, source: &str) -> Option { + let mut name = String::new(); + let mut methods = Vec::new(); + let mut implements = Vec::new(); + let line_number = node.start_position().row + 1; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "identifier" | "type_identifier" => { + if name.is_empty() { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = text.to_string(); + } + } + } + "class_heritage" => { + // Parse extends/implements + let mut heritage_cursor = child.walk(); + for heritage_child in child.children(&mut heritage_cursor) { + if heritage_child.kind() == "identifier" || heritage_child.kind() == "type_identifier" { + if let Ok(text) = heritage_child.utf8_text(source.as_bytes()) { + implements.push(text.to_string()); + } + } + } + } + "class_body" => { + // Extract method names + let mut body_cursor = child.walk(); + for body_child in child.children(&mut body_cursor) { + if body_child.kind() == "method_definition" { + if let Some(method_name) = self.get_method_name(body_child, source) { + methods.push(method_name); + } + } + } + } + _ => {} + } + } + + if !name.is_empty() { + Some(TypeDefinition { + name, + kind: TypeKind::Class, + fields: Vec::new(), + methods, + implements, + doc_comment: None, + line_number, + }) + } else { + None + } + } + + fn parse_interface(&self, node: Node, source: &str) -> Option { + let mut name = String::new(); + let line_number = node.start_position().row + 1; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.kind() == "type_identifier" { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = text.to_string(); + break; + } + } + } + + if !name.is_empty() { + Some(TypeDefinition { + name, + kind: TypeKind::Interface, + fields: Vec::new(), + methods: Vec::new(), + implements: Vec::new(), + doc_comment: None, + line_number, + }) + } else { + None + } + } + + fn parse_type_alias(&self, node: Node, 
source: &str) -> Option { + let mut name = String::new(); + let line_number = node.start_position().row + 1; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.kind() == "type_identifier" { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = text.to_string(); + break; + } + } + } + + if !name.is_empty() { + Some(TypeDefinition { + name, + kind: TypeKind::Type, + fields: Vec::new(), + methods: Vec::new(), + implements: Vec::new(), + doc_comment: None, + line_number, + }) + } else { + None + } + } + + fn get_method_name(&self, node: Node, source: &str) -> Option { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.kind() == "property_identifier" { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + return Some(text.to_string()); + } + } + } + None + } + + fn extract_imports_recursive( + &self, + node: Node, + source: &str, + imports: &mut Vec, + ) { + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + if child.kind() == "import_statement" { + if let Some(import) = self.parse_import(child, source) { + imports.push(import); + } + } else { + self.extract_imports_recursive(child, source, imports); + } + } + } + + fn parse_import(&self, node: Node, source: &str) -> Option { + let mut module = String::new(); + let mut items = Vec::new(); + let mut is_wildcard = false; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "string" => { + // Module path + if let Ok(text) = child.utf8_text(source.as_bytes()) { + module = text.trim_matches('"').trim_matches('\'').to_string(); + } + } + "import_clause" => { + // Parse named imports + let mut clause_cursor = child.walk(); + for clause_child in child.children(&mut clause_cursor) { + match clause_child.kind() { + "identifier" => { + if let Ok(text) = clause_child.utf8_text(source.as_bytes()) { + items.push(text.to_string()); + } + } + "named_imports" => { + let 
mut named_cursor = clause_child.walk(); + for named_child in clause_child.children(&mut named_cursor) { + if named_child.kind() == "import_specifier" { + if let Some(name) = self.get_import_specifier_name(named_child, source) { + items.push(name); + } + } + } + } + "namespace_import" => { + is_wildcard = true; + } + _ => {} + } + } + } + _ => {} + } + } + + if !module.is_empty() { + Some(ImportStatement { + module, + items, + is_wildcard, + }) + } else { + None + } + } + + fn get_import_specifier_name(&self, node: Node, source: &str) -> Option { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.kind() == "identifier" { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + return Some(text.to_string()); + } + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_function_declaration() { + let code = r#" +function greet(name) { + return `Hello, ${name}!`; +} +"#; + + let extractor = JavaScriptExtractor::new("javascript"); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 1); + assert_eq!(ast.functions[0].name, "greet"); + assert_eq!(ast.functions[0].parameters.len(), 1); + assert_eq!(ast.functions[0].parameters[0].name, "name"); + } + + #[test] + fn test_extract_arrow_function() { + let code = r#" +const add = (a, b) => a + b; +"#; + + let extractor = JavaScriptExtractor::new("javascript"); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 1); + assert_eq!(ast.functions[0].name, "add"); + assert_eq!(ast.functions[0].parameters.len(), 2); + } + + #[test] + fn test_extract_async_function() { + let code = r#" +async function fetchData(url) { + const response = await fetch(url); + return response.json(); +} +"#; + + let extractor = JavaScriptExtractor::new("javascript"); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 1); + assert_eq!(ast.functions[0].name, "fetchData"); + 
assert!(ast.functions[0].is_async); + } + + #[test] + fn test_extract_class() { + let code = r#" +class User { + constructor(name) { + this.name = name; + } + + greet() { + return `Hello, ${this.name}`; + } +} +"#; + + let extractor = JavaScriptExtractor::new("javascript"); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.types.len(), 1); + assert_eq!(ast.types[0].name, "User"); + assert_eq!(ast.types[0].kind, TypeKind::Class); + assert_eq!(ast.types[0].methods.len(), 2); + assert!(ast.types[0].methods.contains(&"constructor".to_string())); + assert!(ast.types[0].methods.contains(&"greet".to_string())); + } + + #[test] + fn test_extract_imports() { + let code = r#" +import React from 'react'; +import { useState, useEffect } from 'react'; +import * as utils from './utils'; +"#; + + let extractor = JavaScriptExtractor::new("javascript"); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.imports.len(), 3); + assert!(ast.imports.iter().any(|i| i.module == "react" && i.items.contains(&"React".to_string()))); + assert!(ast.imports.iter().any(|i| i.module == "react" && i.items.contains(&"useState".to_string()))); + assert!(ast.imports.iter().any(|i| i.module == "./utils" && i.is_wildcard)); + } +} diff --git a/crates/reposcout-ast/src/extractors/mod.rs b/crates/reposcout-ast/src/extractors/mod.rs new file mode 100644 index 0000000..d1c4fde --- /dev/null +++ b/crates/reposcout-ast/src/extractors/mod.rs @@ -0,0 +1,75 @@ +use crate::error::Result; +use reposcout_core::models::{AstMetadata, FunctionSignature, ImportStatement, TypeDefinition}; +use tree_sitter::Node; + +pub mod rust; +pub mod python; +pub mod javascript; +pub mod go; + +/// Trait for language-specific AST extractors +pub trait AstExtractor { + /// Get the language name + fn language(&self) -> &'static str; + + /// Extract functions from AST + fn extract_functions(&self, node: Node, source: &str) -> Vec; + + /// Extract types (classes, structs, etc.) 
from AST + fn extract_types(&self, node: Node, source: &str) -> Vec; + + /// Extract imports from AST + fn extract_imports(&self, node: Node, source: &str) -> Vec; + + /// Extract all metadata from code + fn extract_all(&self, code: &str) -> Result { + let cache = crate::parser::ParserCache::get(); + let tree = cache.parse(code, self.language())?; + let root = tree.root_node(); + + Ok(AstMetadata { + language: self.language().to_string(), + functions: self.extract_functions(root, code), + types: self.extract_types(root, code), + imports: self.extract_imports(root, code), + structure_summary: self.generate_summary(root, code), + parse_success: true, + parse_error: None, + }) + } + + /// Generate a text summary of the code structure for embeddings + fn generate_summary(&self, node: Node, source: &str) -> String { + let mut parts = Vec::new(); + + let functions = self.extract_functions(node, source); + if !functions.is_empty() { + let names: Vec<_> = functions.iter().map(|f| f.name.as_str()).collect(); + parts.push(format!("functions: {}", names.join(", "))); + } + + let types = self.extract_types(node, source); + if !types.is_empty() { + let names: Vec<_> = types.iter().map(|t| t.name.as_str()).collect(); + parts.push(format!("types: {}", names.join(", "))); + } + + parts.join(" | ") + } +} + +/// Get extractor for a specific language +pub fn get_extractor(language: &str) -> Option> { + match language.to_lowercase().as_str() { + "rust" => Some(Box::new(rust::RustExtractor)), + "python" | "py" => Some(Box::new(python::PythonExtractor)), + "javascript" | "js" => Some(Box::new(javascript::JavaScriptExtractor::new("javascript"))), + "typescript" | "ts" => Some(Box::new(javascript::JavaScriptExtractor::new("typescript"))), + "tsx" => Some(Box::new(javascript::JavaScriptExtractor::new("tsx"))), + "go" => Some(Box::new(go::GoExtractor)), + // TODO: Add C and C++ as they're implemented + // "c" => Some(Box::new(c::CExtractor)), + // "cpp" | "c++" => 
Some(Box::new(cpp::CppExtractor)), + _ => None, + } +} diff --git a/crates/reposcout-ast/src/extractors/python.rs b/crates/reposcout-ast/src/extractors/python.rs new file mode 100644 index 0000000..bfe12de --- /dev/null +++ b/crates/reposcout-ast/src/extractors/python.rs @@ -0,0 +1,532 @@ +use super::AstExtractor; +use crate::error::{AstError, Result}; +use crate::parser::ParserCache; +use reposcout_core::models::{ + AstMetadata, FunctionSignature, ImportStatement, Parameter, TypeDefinition, TypeKind, + Visibility, +}; +use tree_sitter::Node; + +pub struct PythonExtractor; + +impl AstExtractor for PythonExtractor { + fn language(&self) -> &'static str { + "python" + } + + fn extract_functions(&self, node: Node, source: &str) -> Vec { + let mut functions = Vec::new(); + self.extract_functions_recursive(node, source, &mut functions, false); + functions + } + + fn extract_types(&self, node: Node, source: &str) -> Vec { + let mut types = Vec::new(); + self.extract_types_recursive(node, source, &mut types); + types + } + + fn extract_imports(&self, node: Node, source: &str) -> Vec { + let mut imports = Vec::new(); + self.extract_imports_recursive(node, source, &mut imports); + imports + } + + fn extract_all(&self, code: &str) -> Result { + let cache = ParserCache::get(); + let tree = cache + .parse(code, self.language()) + .map_err(|e| AstError::ParseError(e.to_string()))?; + + let root_node = tree.root_node(); + + let functions = self.extract_functions(root_node, code); + let types = self.extract_types(root_node, code); + let imports = self.extract_imports(root_node, code); + + // Generate structure summary + let mut summary_parts = Vec::new(); + if !functions.is_empty() { + let fn_names: Vec<&str> = functions.iter().map(|f| f.name.as_str()).collect(); + summary_parts.push(format!("functions: {}", fn_names.join(", "))); + } + if !types.is_empty() { + let type_names: Vec<&str> = types.iter().map(|t| t.name.as_str()).collect(); + summary_parts.push(format!("classes: {}", 
type_names.join(", "))); + } + + Ok(AstMetadata { + language: self.language().to_string(), + functions, + types, + imports, + structure_summary: summary_parts.join("; "), + parse_success: true, + parse_error: None, + }) + } +} + +impl PythonExtractor { + fn extract_functions_recursive( + &self, + node: Node, + source: &str, + functions: &mut Vec, + inside_class: bool, + ) { + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + match child.kind() { + "function_definition" => { + if let Some(func) = self.parse_function(child, source, inside_class) { + functions.push(func); + } + } + "class_definition" => { + // Extract methods from class + self.extract_functions_recursive(child, source, functions, true); + } + _ => { + self.extract_functions_recursive(child, source, functions, inside_class); + } + } + } + } + + fn parse_function( + &self, + node: Node, + source: &str, + _inside_class: bool, + ) -> Option { + let mut name = String::new(); + let mut parameters = Vec::new(); + let mut return_type = None; + let mut is_async = false; + let mut doc_comment = None; + let line_number = node.start_position().row + 1; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "async" => { + is_async = true; + } + "identifier" => { + if name.is_empty() { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = text.to_string(); + } + } + } + "parameters" => { + parameters = self.parse_parameters(child, source); + } + "type" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + return_type = Some(text.trim_start_matches("->").trim().to_string()); + } + } + "block" => { + // Try to extract docstring + if let Some(first_stmt) = child.child(1) { + // Skip newline/indent + if first_stmt.kind() == "expression_statement" { + if let Some(string_node) = first_stmt.child(0) { + if string_node.kind() == "string" { + if let Ok(text) = string_node.utf8_text(source.as_bytes()) { + doc_comment = Some( + 
text.trim_matches('"') + .trim_matches('\'') + .trim() + .to_string(), + ); + } + } + } + } + } + } + _ => {} + } + } + + if !name.is_empty() { + // Determine visibility based on naming convention + let visibility = if name.starts_with("__") && !name.ends_with("__") { + Visibility::Private + } else if name.starts_with('_') { + Visibility::Protected + } else { + Visibility::Public + }; + + // Skip 'self' and 'cls' parameters for cleaner display + parameters.retain(|p| p.name != "self" && p.name != "cls"); + + Some(FunctionSignature { + name, + parameters, + return_type, + visibility, + is_async, + is_generic: false, // Python doesn't have explicit generics like Rust + doc_comment, + line_number, + }) + } else { + None + } + } + + fn parse_parameters(&self, node: Node, source: &str) -> Vec { + let mut parameters = Vec::new(); + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + if child.kind() == "identifier" || child.kind() == "typed_parameter" { + if let Some(param) = self.parse_parameter(child, source) { + parameters.push(param); + } + } else if child.kind() == "default_parameter" || child.kind() == "typed_default_parameter" + { + if let Some(param) = self.parse_default_parameter(child, source) { + parameters.push(param); + } + } + } + + parameters + } + + fn parse_parameter(&self, node: Node, source: &str) -> Option { + let mut name = String::new(); + let mut param_type = None; + + if node.kind() == "identifier" { + if let Ok(text) = node.utf8_text(source.as_bytes()) { + name = text.to_string(); + } + } else if node.kind() == "typed_parameter" { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "identifier" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = text.to_string(); + } + } + "type" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + param_type = Some(text.trim_start_matches(':').trim().to_string()); + } + } + _ => {} + } + } + } + + if 
!name.is_empty() { + Some(Parameter { + name, + param_type, + is_optional: false, + }) + } else { + None + } + } + + fn parse_default_parameter(&self, node: Node, source: &str) -> Option { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.kind() == "identifier" || child.kind() == "typed_parameter" { + return self.parse_parameter(child, source); + } + } + None + } + + fn extract_types_recursive( + &self, + node: Node, + source: &str, + types: &mut Vec, + ) { + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + if child.kind() == "class_definition" { + if let Some(type_def) = self.parse_class(child, source) { + types.push(type_def); + } + } + self.extract_types_recursive(child, source, types); + } + } + + fn parse_class(&self, node: Node, source: &str) -> Option { + let mut name = String::new(); + let mut methods = Vec::new(); + let mut implements = Vec::new(); + let mut doc_comment = None; + let line_number = node.start_position().row + 1; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "identifier" => { + if name.is_empty() { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = text.to_string(); + } + } + } + "argument_list" => { + // Parse base classes + let mut arg_cursor = child.walk(); + for arg in child.children(&mut arg_cursor) { + if arg.kind() == "identifier" { + if let Ok(text) = arg.utf8_text(source.as_bytes()) { + implements.push(text.to_string()); + } + } + } + } + "block" => { + // Extract docstring + if let Some(first_stmt) = child.child(1) { + if first_stmt.kind() == "expression_statement" { + if let Some(string_node) = first_stmt.child(0) { + if string_node.kind() == "string" { + if let Ok(text) = string_node.utf8_text(source.as_bytes()) { + doc_comment = Some( + text.trim_matches('"') + .trim_matches('\'') + .trim() + .to_string(), + ); + } + } + } + } + } + + // Extract method names + let mut block_cursor = 
child.walk(); + for block_child in child.children(&mut block_cursor) { + if block_child.kind() == "function_definition" { + if let Some(method_name) = self.get_function_name(block_child, source) { + methods.push(method_name); + } + } + } + } + _ => {} + } + } + + if !name.is_empty() { + Some(TypeDefinition { + name, + kind: TypeKind::Class, + fields: Vec::new(), // Python doesn't have explicit field declarations + methods, + implements, + doc_comment, + line_number, + }) + } else { + None + } + } + + fn get_function_name(&self, node: Node, source: &str) -> Option { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.kind() == "identifier" { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + return Some(text.to_string()); + } + } + } + None + } + + fn extract_imports_recursive( + &self, + node: Node, + source: &str, + imports: &mut Vec, + ) { + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + match child.kind() { + "import_statement" | "import_from_statement" => { + if let Some(import) = self.parse_import(child, source) { + imports.push(import); + } + } + _ => { + self.extract_imports_recursive(child, source, imports); + } + } + } + } + + fn parse_import(&self, node: Node, source: &str) -> Option { + if let Ok(text) = node.utf8_text(source.as_bytes()) { + let text = text.trim(); + + // Parse "import module" or "from module import items" + if text.starts_with("from ") { + // from module import item1, item2 + let parts: Vec<&str> = text.split("import").collect(); + if parts.len() == 2 { + let module = parts[0] + .trim_start_matches("from ") + .trim() + .to_string(); + let items_str = parts[1].trim(); + let is_wildcard = items_str == "*"; + let items = if is_wildcard { + vec![] + } else { + items_str + .split(',') + .map(|s| s.trim().to_string()) + .collect() + }; + + return Some(ImportStatement { + module, + items, + is_wildcard, + }); + } + } else if text.starts_with("import ") { + // import 
module + let module = text.trim_start_matches("import ").trim().to_string(); + if !module.is_empty() { + return Some(ImportStatement { + module, + items: vec![], + is_wildcard: false, + }); + } + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_python_function() { + let code = r#" +def greet(name: str) -> str: + """Greet a person by name.""" + return f"Hello, {name}!" +"#; + + let extractor = PythonExtractor; + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 1); + assert_eq!(ast.functions[0].name, "greet"); + assert_eq!(ast.functions[0].parameters.len(), 1); + assert_eq!(ast.functions[0].parameters[0].name, "name"); + assert_eq!( + ast.functions[0].parameters[0].param_type, + Some("str".to_string()) + ); + assert_eq!(ast.functions[0].return_type, Some("str".to_string())); + } + + #[test] + fn test_extract_async_function() { + let code = r#" +async def fetch_data(url: str): + """Fetch data from URL.""" + pass +"#; + + let extractor = PythonExtractor; + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 1); + assert_eq!(ast.functions[0].name, "fetch_data"); + assert!(ast.functions[0].is_async); + } + + #[test] + fn test_extract_class() { + let code = r#" +class User: + """A user class.""" + + def __init__(self, name: str): + self.name = name + + def greet(self): + return f"Hello, {self.name}" +"#; + + let extractor = PythonExtractor; + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.types.len(), 1); + assert_eq!(ast.types[0].name, "User"); + assert_eq!(ast.types[0].kind, TypeKind::Class); + assert_eq!(ast.types[0].methods.len(), 2); + assert!(ast.types[0].methods.contains(&"__init__".to_string())); + assert!(ast.types[0].methods.contains(&"greet".to_string())); + } + + #[test] + fn test_visibility_from_naming() { + let code = r#" +def public_func(): + pass + +def _protected_func(): + pass + +def __private_func(): + pass +"#; + + let 
extractor = PythonExtractor; + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 3); + assert_eq!(ast.functions[0].visibility, Visibility::Public); + assert_eq!(ast.functions[1].visibility, Visibility::Protected); + assert_eq!(ast.functions[2].visibility, Visibility::Private); + } + + #[test] + fn test_extract_imports() { + let code = r#" +import os +import sys +from typing import List, Dict +from pathlib import Path +"#; + + let extractor = PythonExtractor; + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.imports.len(), 4); + assert!(ast.imports.iter().any(|i| i.module == "os")); + assert!(ast.imports.iter().any(|i| i.module == "sys")); + assert!(ast.imports.iter().any(|i| i.module == "typing")); + assert!(ast.imports.iter().any(|i| i.module == "pathlib")); + } +} diff --git a/crates/reposcout-ast/src/extractors/rust.rs b/crates/reposcout-ast/src/extractors/rust.rs new file mode 100644 index 0000000..7176902 --- /dev/null +++ b/crates/reposcout-ast/src/extractors/rust.rs @@ -0,0 +1,358 @@ +use super::AstExtractor; +use reposcout_core::models::{ + Field, FunctionSignature, ImportStatement, Parameter, TypeDefinition, TypeKind, Visibility, +}; +use tree_sitter::Node; + +pub struct RustExtractor; + +impl AstExtractor for RustExtractor { + fn language(&self) -> &'static str { + "rust" + } + + fn extract_functions(&self, node: Node, source: &str) -> Vec { + let mut functions = Vec::new(); + visit_functions(node, source, &mut functions); + functions + } + + fn extract_types(&self, node: Node, source: &str) -> Vec { + let mut types = Vec::new(); + visit_types(node, source, &mut types); + types + } + + fn extract_imports(&self, node: Node, source: &str) -> Vec { + let mut imports = Vec::new(); + visit_imports(node, source, &mut imports); + imports + } +} + +/// Recursively visit nodes to find function declarations +fn visit_functions(node: Node, source: &str, functions: &mut Vec) { + if node.kind() == "function_item" { 
+ if let Some(func) = parse_function(node, source) { + functions.push(func); + } + } + + // Recurse to children + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + visit_functions(child, source, functions); + } +} + +/// Parse a function node into FunctionSignature +fn parse_function(node: Node, source: &str) -> Option { + let mut name = None; + let mut parameters = Vec::new(); + let mut return_type = None; + let mut visibility = Visibility::Private; + let mut is_async = false; + let mut found_arrow = false; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "visibility_modifier" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + if text.contains("pub") { + visibility = Visibility::Public; + } + } + } + "identifier" => { + if name.is_none() && !found_arrow { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = Some(text.to_string()); + } + } + } + "parameters" => { + parameters = parse_parameters(child, source); + } + "->" => { + found_arrow = true; + } + _ if found_arrow + && return_type.is_none() + && (child.kind().ends_with("_type") || child.kind() == "type_identifier") => + { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + return_type = Some(text.to_string()); + } + } + _ => {} + } + } + + // Check for async modifier by examining the node text + if let Ok(text) = node.utf8_text(source.as_bytes()) { + is_async = text.contains("async fn"); + } + + Some(FunctionSignature { + name: name?, + parameters, + return_type, + visibility, + is_async, + is_generic: false, // TODO: detect generics + doc_comment: None, // TODO: extract doc comments + line_number: node.start_position().row + 1, + }) +} + +/// Parse function parameters +fn parse_parameters(node: Node, source: &str) -> Vec { + let mut params = Vec::new(); + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + if child.kind() == "parameter" || child.kind() == "self_parameter" 
{ + if let Some(param) = parse_parameter(child, source) { + params.push(param); + } + } + } + + params +} + +/// Parse a single parameter +fn parse_parameter(node: Node, source: &str) -> Option { + let mut name = None; + let mut param_type = None; + + // Handle self parameter specially + if node.kind() == "self_parameter" { + if let Ok(text) = node.utf8_text(source.as_bytes()) { + return Some(Parameter { + name: text.to_string(), + param_type: None, + is_optional: false, + }); + } + } + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "identifier" | "pattern" => { + if name.is_none() { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + // For patterns, extract just the identifier + let clean_name = text.split(':').next().unwrap_or(text).trim(); + name = Some(clean_name.to_string()); + } + } + } + _ if child.kind().ends_with("_type") || child.kind() == "type_identifier" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + param_type = Some(text.to_string()); + } + } + _ => {} + } + } + + Some(Parameter { + name: name?, + param_type, + is_optional: false, + }) +} + +/// Recursively visit nodes to find type declarations +fn visit_types(node: Node, source: &str, types: &mut Vec) { + match node.kind() { + "struct_item" => { + if let Some(type_def) = parse_struct(node, source) { + types.push(type_def); + } + } + "enum_item" => { + if let Some(type_def) = parse_enum(node, source) { + types.push(type_def); + } + } + "trait_item" => { + if let Some(type_def) = parse_trait(node, source) { + types.push(type_def); + } + } + _ => {} + } + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + visit_types(child, source, types); + } +} + +/// Parse a struct definition +fn parse_struct(node: Node, source: &str) -> Option { + let mut name = None; + let mut fields = Vec::new(); + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + 
"type_identifier" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = Some(text.to_string()); + } + } + "field_declaration_list" => { + fields = parse_fields(child, source); + } + _ => {} + } + } + + Some(TypeDefinition { + name: name?, + kind: TypeKind::Struct, + fields, + methods: Vec::new(), // TODO: associate impl blocks + implements: Vec::new(), + doc_comment: None, + line_number: node.start_position().row + 1, + }) +} + +/// Parse struct fields +fn parse_fields(node: Node, source: &str) -> Vec { + let mut fields = Vec::new(); + let mut cursor = node.walk(); + + for child in node.children(&mut cursor) { + if child.kind() == "field_declaration" { + if let Some(field) = parse_field(child, source) { + fields.push(field); + } + } + } + + fields +} + +/// Parse a single field +fn parse_field(node: Node, source: &str) -> Option { + let mut name = None; + let mut field_type = None; + let mut visibility = Visibility::Private; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "visibility_modifier" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + if text.contains("pub") { + visibility = Visibility::Public; + } + } + } + "field_identifier" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = Some(text.to_string()); + } + } + _ if child.kind().ends_with("_type") || child.kind() == "type_identifier" => { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + field_type = Some(text.to_string()); + } + } + _ => {} + } + } + + Some(Field { + name: name?, + field_type, + visibility, + }) +} + +/// Parse an enum definition +fn parse_enum(node: Node, source: &str) -> Option { + let mut name = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.kind() == "type_identifier" { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = Some(text.to_string()); + } + break; + } + } + + Some(TypeDefinition { + name: name?, + kind: 
TypeKind::Enum, + fields: Vec::new(), + methods: Vec::new(), + implements: Vec::new(), + doc_comment: None, + line_number: node.start_position().row + 1, + }) +} + +/// Parse a trait definition +fn parse_trait(node: Node, source: &str) -> Option { + let mut name = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.kind() == "type_identifier" { + if let Ok(text) = child.utf8_text(source.as_bytes()) { + name = Some(text.to_string()); + } + break; + } + } + + Some(TypeDefinition { + name: name?, + kind: TypeKind::Trait, + fields: Vec::new(), + methods: Vec::new(), + implements: Vec::new(), + doc_comment: None, + line_number: node.start_position().row + 1, + }) +} + +/// Recursively visit nodes to find use declarations +fn visit_imports(node: Node, source: &str, imports: &mut Vec) { + if node.kind() == "use_declaration" { + if let Some(import) = parse_use(node, source) { + imports.push(import); + } + } + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + visit_imports(child, source, imports); + } +} + +/// Parse a use declaration +fn parse_use(node: Node, source: &str) -> Option { + // Simple implementation: just get the whole use statement as module name + let text = node.utf8_text(source.as_bytes()).ok()?; + let cleaned = text + .trim_start_matches("use ") + .trim_end_matches(';') + .trim(); + + Some(ImportStatement { + module: cleaned.to_string(), + items: Vec::new(), + is_wildcard: cleaned.contains('*'), + }) +} diff --git a/crates/reposcout-ast/src/lib.rs b/crates/reposcout-ast/src/lib.rs new file mode 100644 index 0000000..0d07011 --- /dev/null +++ b/crates/reposcout-ast/src/lib.rs @@ -0,0 +1,36 @@ +//! AST-based code analysis for RepoScout +//! +//! This crate provides Abstract Syntax Tree (AST) parsing and analysis for multiple +//! programming languages using tree-sitter. It enables structural code search by +//! extracting functions, types, and other code elements. +//! +//! 
# Examples +//! +//! ``` +//! use reposcout_ast::extractors::{get_extractor, AstExtractor}; +//! +//! let code = r#" +//! pub fn hello(name: &str) -> String { +//! format!("Hello, {}!", name) +//! } +//! "#; +//! +//! let extractor = get_extractor("rust").unwrap(); +//! let ast = extractor.extract_all(code).unwrap(); +//! +//! assert_eq!(ast.functions.len(), 1); +//! assert_eq!(ast.functions[0].name, "hello"); +//! ``` + +pub mod error; +pub mod extractors; +pub mod parser; +pub mod query; +pub mod scorer; + +// Re-export commonly used types +pub use error::{AstError, Result}; +pub use extractors::{get_extractor, AstExtractor}; +pub use parser::ParserCache; +pub use query::{parse_query, AstQueryFilters, ParsedQuery}; +pub use scorer::{filter_by_ast, score_ast_match, string_similarity}; diff --git a/crates/reposcout-ast/src/parser.rs b/crates/reposcout-ast/src/parser.rs new file mode 100644 index 0000000..011aee4 --- /dev/null +++ b/crates/reposcout-ast/src/parser.rs @@ -0,0 +1,149 @@ +use crate::error::{AstError, Result}; +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::sync::Mutex; +use tree_sitter::{Language, Parser}; + +/// Global parser cache for reuse across requests +pub struct ParserCache { + parsers: Mutex>, + languages: HashMap, +} + +static PARSER_CACHE: Lazy = Lazy::new(|| { + let mut languages = HashMap::new(); + + // Load all tree-sitter grammars + languages.insert("rust".to_string(), tree_sitter_rust::language()); + languages.insert("python".to_string(), tree_sitter_python::language()); + languages.insert("javascript".to_string(), tree_sitter_javascript::language()); + languages.insert( + "typescript".to_string(), + tree_sitter_typescript::language_typescript(), + ); + languages.insert("tsx".to_string(), tree_sitter_typescript::language_tsx()); + languages.insert("go".to_string(), tree_sitter_go::language()); + languages.insert("c".to_string(), tree_sitter_c::language()); + languages.insert("cpp".to_string(), 
tree_sitter_cpp::language()); + languages.insert("c++".to_string(), tree_sitter_cpp::language()); + + ParserCache { + parsers: Mutex::new(HashMap::new()), + languages, + } +}); + +impl ParserCache { + /// Get the global parser cache instance + pub fn get() -> &'static ParserCache { + &PARSER_CACHE + } + + /// Parse code with language-specific parser + pub fn parse(&self, code: &str, language: &str) -> Result { + let lang_lower = language.to_lowercase(); + + let language = self + .languages + .get(&lang_lower) + .ok_or_else(|| AstError::UnsupportedLanguage(language.to_string()))?; + + let mut parsers = self.parsers.lock().unwrap(); + let parser = parsers.entry(lang_lower.clone()).or_insert_with(|| { + let mut p = Parser::new(); + p.set_language(language) + .expect("Failed to set language for parser"); + p + }); + + parser + .parse(code, None) + .ok_or_else(|| AstError::ParseError("Failed to parse code".to_string())) + } + + /// Detect language from file extension + pub fn detect_language(file_path: &str) -> Option { + let ext = std::path::Path::new(file_path) + .extension()? + .to_str()? 
+ .to_lowercase(); + + match ext.as_str() { + "rs" => Some("rust".to_string()), + "py" | "pyw" => Some("python".to_string()), + "js" | "mjs" | "cjs" => Some("javascript".to_string()), + "ts" => Some("typescript".to_string()), + "tsx" => Some("tsx".to_string()), + "go" => Some("go".to_string()), + "c" | "h" => Some("c".to_string()), + "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "h++" => Some("cpp".to_string()), + _ => None, + } + } + + /// Check if a language is supported + pub fn is_language_supported(language: &str) -> bool { + let lang_lower = language.to_lowercase(); + PARSER_CACHE.languages.contains_key(&lang_lower) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_language_detection() { + assert_eq!( + ParserCache::detect_language("test.rs"), + Some("rust".to_string()) + ); + assert_eq!( + ParserCache::detect_language("test.py"), + Some("python".to_string()) + ); + assert_eq!( + ParserCache::detect_language("test.js"), + Some("javascript".to_string()) + ); + assert_eq!( + ParserCache::detect_language("test.ts"), + Some("typescript".to_string()) + ); + assert_eq!( + ParserCache::detect_language("test.go"), + Some("go".to_string()) + ); + assert_eq!(ParserCache::detect_language("test.c"), Some("c".to_string())); + assert_eq!( + ParserCache::detect_language("test.cpp"), + Some("cpp".to_string()) + ); + assert_eq!(ParserCache::detect_language("test.txt"), None); + } + + #[test] + fn test_is_language_supported() { + assert!(ParserCache::is_language_supported("rust")); + assert!(ParserCache::is_language_supported("Rust")); + assert!(ParserCache::is_language_supported("python")); + assert!(!ParserCache::is_language_supported("unknown")); + } + + #[test] + fn test_parse_rust() { + let code = "fn main() {}"; + let cache = ParserCache::get(); + let tree = cache.parse(code, "rust").unwrap(); + assert!(tree.root_node().kind() == "source_file"); + } + + #[test] + fn test_parse_unsupported_language() { + let code = "some code"; + let cache = 
ParserCache::get(); + let result = cache.parse(code, "unknown"); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), AstError::UnsupportedLanguage(_))); + } +} diff --git a/crates/reposcout-ast/src/query.rs b/crates/reposcout-ast/src/query.rs new file mode 100644 index 0000000..f4e4192 --- /dev/null +++ b/crates/reposcout-ast/src/query.rs @@ -0,0 +1,249 @@ +use serde::{Deserialize, Serialize}; + +/// Parsed query with extracted filters +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ParsedQuery { + /// The cleaned search query + pub query: String, + /// Extracted filters + pub filters: AstQueryFilters, +} + +/// AST-based query filters +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +pub struct AstQueryFilters { + /// Function name filter (from "fn:parse" or "--function-name") + pub function_name: Option, + /// Class/struct name filter (from "class:User" or "--class-name") + pub class_name: Option, + /// Async function filter (from "async fn" in query) + pub is_async: Option, + /// Public visibility filter + pub is_public: Option, + /// Generic type filter + pub is_generic: Option, +} + +impl AstQueryFilters { + /// Check if any filters are set + pub fn has_filters(&self) -> bool { + self.function_name.is_some() + || self.class_name.is_some() + || self.is_async.is_some() + || self.is_public.is_some() + || self.is_generic.is_some() + } + + /// Merge with another filter set (other takes precedence) + pub fn merge(&mut self, other: &AstQueryFilters) { + if other.function_name.is_some() { + self.function_name = other.function_name.clone(); + } + if other.class_name.is_some() { + self.class_name = other.class_name.clone(); + } + if other.is_async.is_some() { + self.is_async = other.is_async; + } + if other.is_public.is_some() { + self.is_public = other.is_public; + } + if other.is_generic.is_some() { + self.is_generic = other.is_generic; + } + } +} + +/// Parse a search query to extract AST filters +pub fn 
parse_query(query: &str) -> ParsedQuery { + let mut filters = AstQueryFilters::default(); + let mut cleaned_parts = Vec::new(); + + // Split query into words + let words: Vec<&str> = query.split_whitespace().collect(); + + for word in words { + // Check for prefix syntax: "fn:name", "class:Name" + if let Some((prefix, value)) = word.split_once(':') { + match prefix.to_lowercase().as_str() { + "fn" | "function" | "func" | "method" => { + filters.function_name = Some(value.to_string()); + continue; + } + "class" | "struct" | "type" | "interface" | "trait" => { + filters.class_name = Some(value.to_string()); + continue; + } + _ => {} + } + } + + // Check for standalone keywords + match word.to_lowercase().as_str() { + "async" => { + filters.is_async = Some(true); + continue; + } + "pub" | "public" => { + filters.is_public = Some(true); + continue; + } + "generic" => { + filters.is_generic = Some(true); + continue; + } + _ => {} + } + + // Keep the word in the query + cleaned_parts.push(word); + } + + // Detect natural language intent + detect_intent(&cleaned_parts, &mut filters); + + ParsedQuery { + query: cleaned_parts.join(" "), + filters, + } +} + +/// Detect intent from natural language queries +fn detect_intent(words: &[&str], filters: &mut AstQueryFilters) { + let query_lower = words.join(" ").to_lowercase(); + + // Detect function-related queries + if query_lower.contains("function") || query_lower.contains("method") { + // Look for "function that" or "function to" patterns + if query_lower.contains("function that") || query_lower.contains("function to") { + // This is a function search + if filters.function_name.is_none() { + // Try to extract the function purpose + // e.g., "function that parses json" -> might want to search for "parse" + for (i, word) in words.iter().enumerate() { + if word.to_lowercase() == "function" && i + 2 < words.len() { + // Skip "that" or "to" + if ["that", "to", "which"].contains(&words[i + 1].to_lowercase().as_str()) + { + if let 
Some(&verb) = words.get(i + 2) { + // Use the verb as a hint + filters.function_name = Some(verb.to_string()); + } + } + } + } + } + } + } + + // Detect async pattern + if query_lower.contains("async function") + || query_lower.contains("asynchronous") + || query_lower.contains("async fn") + { + filters.is_async = Some(true); + } + + // Detect class/struct pattern + if query_lower.contains("class") || query_lower.contains("struct") { + // This might be a type search + // e.g., "class User" or "struct Config" + for (i, word) in words.iter().enumerate() { + if ["class", "struct", "type"] + .contains(&word.to_lowercase().as_str()) + && i + 1 < words.len() + { + if let Some(&name) = words.get(i + 1) { + if !["that", "which", "with", "for"] + .contains(&name.to_lowercase().as_str()) + { + filters.class_name = Some(name.to_string()); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_function_prefix() { + let parsed = parse_query("fn:parse json data"); + assert_eq!(parsed.filters.function_name, Some("parse".to_string())); + assert_eq!(parsed.query, "json data"); + } + + #[test] + fn test_parse_class_prefix() { + let parsed = parse_query("class:User authentication"); + assert_eq!(parsed.filters.class_name, Some("User".to_string())); + assert_eq!(parsed.query, "authentication"); + } + + #[test] + fn test_parse_async_keyword() { + let parsed = parse_query("async fetch data"); + assert_eq!(parsed.filters.is_async, Some(true)); + assert_eq!(parsed.query, "fetch data"); + } + + #[test] + fn test_parse_combined_filters() { + let parsed = parse_query("async fn:fetch_data public"); + assert_eq!(parsed.filters.function_name, Some("fetch_data".to_string())); + assert_eq!(parsed.filters.is_async, Some(true)); + assert_eq!(parsed.filters.is_public, Some(true)); + assert_eq!(parsed.query, ""); + } + + #[test] + fn test_parse_natural_language_function() { + let parsed = parse_query("function that parses command line arguments"); + // Should 
detect this is a function query + assert_eq!(parsed.filters.function_name, Some("parses".to_string())); + assert!(parsed.query.contains("command")); + } + + #[test] + fn test_parse_natural_language_async() { + let parsed = parse_query("async function to fetch user data"); + assert_eq!(parsed.filters.is_async, Some(true)); + } + + #[test] + fn test_parse_class_name() { + let parsed = parse_query("class User with email field"); + assert_eq!(parsed.filters.class_name, Some("User".to_string())); + assert!(parsed.query.contains("email")); + } + + #[test] + fn test_no_filters() { + let parsed = parse_query("search for authentication code"); + assert!(!parsed.filters.has_filters()); + assert_eq!(parsed.query, "search for authentication code"); + } + + #[test] + fn test_merge_filters() { + let mut filters1 = AstQueryFilters { + function_name: Some("test".to_string()), + is_async: Some(true), + ..Default::default() + }; + + let filters2 = AstQueryFilters { + function_name: Some("override".to_string()), + class_name: Some("User".to_string()), + ..Default::default() + }; + + filters1.merge(&filters2); + assert_eq!(filters1.function_name, Some("override".to_string())); + assert_eq!(filters1.is_async, Some(true)); + assert_eq!(filters1.class_name, Some("User".to_string())); + } +} diff --git a/crates/reposcout-ast/src/scorer.rs b/crates/reposcout-ast/src/scorer.rs new file mode 100644 index 0000000..b7d7ea9 --- /dev/null +++ b/crates/reposcout-ast/src/scorer.rs @@ -0,0 +1,370 @@ +use crate::query::AstQueryFilters; +use reposcout_core::models::AstMetadata; + +#[cfg(test)] +use reposcout_core::models::{FunctionSignature, TypeDefinition}; + +/// Score AST metadata against query filters +/// Returns a score from 0.0 (no match) to 1.0 (perfect match) +pub fn score_ast_match(ast: &AstMetadata, filters: &AstQueryFilters, query: &str) -> f32 { + if !filters.has_filters() { + // No AST filters, just do text matching on structure + return score_text_match(ast, query); + } + + let mut 
score = 0.0; + let mut weight_sum = 0.0; + + // Function name match (weight: 1.0) + if let Some(ref fn_name) = filters.function_name { + weight_sum += 1.0; + for func in &ast.functions { + let match_score = string_similarity(&func.name, fn_name); + if match_score > score { + score += match_score; + } + } + } + + // Class/struct name match (weight: 1.0) + if let Some(ref class_name) = filters.class_name { + weight_sum += 1.0; + for type_def in &ast.types { + let match_score = string_similarity(&type_def.name, class_name); + if match_score > score { + score += match_score; + } + } + } + + // Async modifier match (weight: 0.5) + if let Some(is_async) = filters.is_async { + weight_sum += 0.5; + let has_async = ast.functions.iter().any(|f| f.is_async); + if has_async == is_async { + score += 0.5; + } + } + + // Public visibility match (weight: 0.3) + if let Some(is_public) = filters.is_public { + weight_sum += 0.3; + let has_public = ast.functions.iter().any(|f| f.visibility == reposcout_core::models::Visibility::Public) + || ast.types.iter().any(|t| { + t.fields.iter().any(|field| field.visibility == reposcout_core::models::Visibility::Public) + }); + if has_public == is_public { + score += 0.3; + } + } + + // Generic type match (weight: 0.3) + if let Some(is_generic) = filters.is_generic { + weight_sum += 0.3; + let has_generic = ast.functions.iter().any(|f| f.is_generic); + if has_generic == is_generic { + score += 0.3; + } + } + + // Normalize score by total weight + if weight_sum > 0.0 { + score / weight_sum + } else { + 0.0 + } +} + +/// Score based on text similarity to query +fn score_text_match(ast: &AstMetadata, query: &str) -> f32 { + if query.is_empty() { + return 0.5; // Neutral score if no query text + } + + let query_lower = query.to_lowercase(); + let mut best_score = 0.0; + + // Check function names + for func in &ast.functions { + let func_score = fuzzy_contains(&func.name.to_lowercase(), &query_lower); + if func_score > best_score { + best_score = 
/// Calculate a case-insensitive similarity score between two strings.
///
/// Both inputs are lowercased first. The result is:
///
/// * `1.0` when the strings are equal (this includes two empty strings);
/// * `0.8` when one string contains the other — this also covers prefix
///   matches, and, because the empty string is a substring of every string,
///   any comparison where exactly one side is empty;
/// * otherwise `1 - distance / max_len`, where `distance` is the Levenshtein
///   edit distance and `max_len` the longer length, both counted in
///   characters (not bytes). This value is always below `1.0`.
pub fn string_similarity(s1: &str, s2: &str) -> f32 {
    let a = s1.to_lowercase();
    let b = s2.to_lowercase();

    // Exact match
    if a == b {
        return 1.0;
    }

    // Contains match (a prefix match is a special case of this)
    if a.contains(&b) || b.contains(&a) {
        return 0.8;
    }

    // Normalised edit distance. Neither string is empty here: an empty
    // string would have been caught by the contains check above, so
    // `max_len` is at least 1.
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();
    let max_len = a_chars.len().max(b_chars.len());
    let distance = levenshtein_distance(&a_chars, &b_chars);

    1.0 - distance as f32 / max_len as f32
}

/// Levenshtein edit distance between two character sequences, computed with
/// the two-row dynamic-programming scheme (O(len(a) * len(b)) time,
/// O(len(b)) extra space).
fn levenshtein_distance(a: &[char], b: &[char]) -> usize {
    // `prev[j]` holds the distance between the processed prefix of `a` and
    // the first `j` characters of `b`.
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    let mut curr: Vec<usize> = vec![0; b.len() + 1];

    for (i, &ca) in a.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            let substitution = prev[j] + usize::from(ca != cb);
            let deletion = prev[j + 1] + 1;
            let insertion = curr[j] + 1;
            curr[j + 1] = substitution.min(deletion).min(insertion);
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[b.len()]
}
filters.function_name { + if !ast.functions.iter().any(|f| { + let similarity = string_similarity(&f.name, fn_name); + similarity > 0.5 // At least 50% similar + }) { + return false; + } + } + + if let Some(ref class_name) = filters.class_name { + if !ast.types.iter().any(|t| { + let similarity = string_similarity(&t.name, class_name); + similarity > 0.5 + }) { + return false; + } + } + + if let Some(is_async) = filters.is_async { + let has_async = ast.functions.iter().any(|f| f.is_async); + if has_async != is_async { + return false; + } + } + + true + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use reposcout_core::models::{ImportStatement, Parameter, Visibility}; + + fn create_test_function(name: &str, is_async: bool, is_public: bool) -> FunctionSignature { + FunctionSignature { + name: name.to_string(), + parameters: vec![], + return_type: None, + visibility: if is_public { + Visibility::Public + } else { + Visibility::Private + }, + is_async, + is_generic: false, + doc_comment: None, + line_number: 1, + } + } + + fn create_test_type(name: &str) -> TypeDefinition { + TypeDefinition { + name: name.to_string(), + kind: reposcout_core::models::TypeKind::Struct, + fields: vec![], + methods: vec![], + implements: vec![], + doc_comment: None, + line_number: 1, + } + } + + #[test] + fn test_function_name_exact_match() { + let ast = AstMetadata { + language: "rust".to_string(), + functions: vec![create_test_function("parse_args", false, true)], + types: vec![], + imports: vec![], + structure_summary: String::new(), + parse_success: true, + parse_error: None, + }; + + let filters = AstQueryFilters { + function_name: Some("parse_args".to_string()), + ..Default::default() + }; + + let score = score_ast_match(&ast, &filters, ""); + assert_eq!(score, 1.0, "Exact function name match should score 1.0"); + } + + #[test] + fn test_function_name_partial_match() { + let ast = AstMetadata { + language: "rust".to_string(), + functions: 
vec![create_test_function("parse_arguments", false, true)], + types: vec![], + imports: vec![], + structure_summary: String::new(), + parse_success: true, + parse_error: None, + }; + + let filters = AstQueryFilters { + function_name: Some("parse".to_string()), + ..Default::default() + }; + + let score = score_ast_match(&ast, &filters, ""); + assert!(score > 0.7, "Partial match should score > 0.7"); + } + + #[test] + fn test_async_filter() { + let ast = AstMetadata { + language: "rust".to_string(), + functions: vec![create_test_function("fetch", true, true)], + types: vec![], + imports: vec![], + structure_summary: String::new(), + parse_success: true, + parse_error: None, + }; + + let filters = AstQueryFilters { + is_async: Some(true), + ..Default::default() + }; + + let score = score_ast_match(&ast, &filters, ""); + assert_eq!(score, 1.0, "Async match should score 1.0"); + } + + #[test] + fn test_class_name_match() { + let ast = AstMetadata { + language: "rust".to_string(), + functions: vec![], + types: vec![create_test_type("User")], + imports: vec![], + structure_summary: String::new(), + parse_success: true, + parse_error: None, + }; + + let filters = AstQueryFilters { + class_name: Some("User".to_string()), + ..Default::default() + }; + + let score = score_ast_match(&ast, &filters, ""); + assert_eq!(score, 1.0, "Exact class name match should score 1.0"); + } + + #[test] + fn test_combined_filters() { + let ast = AstMetadata { + language: "rust".to_string(), + functions: vec![create_test_function("fetch_data", true, true)], + types: vec![], + imports: vec![], + structure_summary: String::new(), + parse_success: true, + parse_error: None, + }; + + let filters = AstQueryFilters { + function_name: Some("fetch_data".to_string()), + is_async: Some(true), + is_public: Some(true), + ..Default::default() + }; + + let score = score_ast_match(&ast, &filters, ""); + assert_eq!(score, 1.0, "All matching filters should score 1.0"); + } + + #[test] + fn 
test_text_match_no_filters() { + let ast = AstMetadata { + language: "rust".to_string(), + functions: vec![create_test_function("parse_json", false, true)], + types: vec![], + imports: vec![], + structure_summary: "functions: parse_json".to_string(), + parse_success: true, + parse_error: None, + }; + + let filters = AstQueryFilters::default(); + let score = score_ast_match(&ast, &filters, "parse json"); + assert!(score > 0.5, "Text match should score > 0.5"); + } + + #[test] + fn test_string_similarity() { + assert_eq!(string_similarity("parse", "parse"), 1.0); + assert!(string_similarity("parse_args", "parse") > 0.7); + assert!(string_similarity("Parser", "parser") > 0.9); + } +} diff --git a/crates/reposcout-ast/tests/rust_extraction_test.rs b/crates/reposcout-ast/tests/rust_extraction_test.rs new file mode 100644 index 0000000..193e5eb --- /dev/null +++ b/crates/reposcout-ast/tests/rust_extraction_test.rs @@ -0,0 +1,237 @@ +use reposcout_ast::extractors::get_extractor; +use reposcout_core::models::{TypeKind, Visibility}; + +#[test] +fn test_extract_simple_function() { + let code = r#" + pub fn add(a: i32, b: i32) -> i32 { + a + b + } + "#; + + let extractor = get_extractor("rust").expect("Rust extractor should be available"); + let ast = extractor.extract_all(code).unwrap(); + + assert!(ast.parse_success); + assert_eq!(ast.functions.len(), 1); + + let func = &ast.functions[0]; + assert_eq!(func.name, "add"); + assert_eq!(func.parameters.len(), 2); + assert_eq!(func.parameters[0].name, "a"); + assert_eq!(func.parameters[0].param_type, Some("i32".to_string())); + assert_eq!(func.parameters[1].name, "b"); + assert_eq!(func.parameters[1].param_type, Some("i32".to_string())); + assert_eq!(func.return_type, Some("i32".to_string())); + assert_eq!(func.visibility, Visibility::Public); + assert!(!func.is_async); +} + +#[test] +fn test_extract_async_function() { + let code = r#" + async fn fetch_data() -> Result { + Ok("data".to_string()) + } + "#; + + let extractor = 
get_extractor("rust").unwrap(); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 1); + let func = &ast.functions[0]; + assert_eq!(func.name, "fetch_data"); + assert!(func.is_async); + assert_eq!(func.return_type, Some("Result".to_string())); +} + +#[test] +fn test_extract_function_with_self() { + let code = r#" + pub fn method(&self, value: u32) -> bool { + true + } + "#; + + let extractor = get_extractor("rust").unwrap(); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 1); + let func = &ast.functions[0]; + assert_eq!(func.name, "method"); + assert_eq!(func.parameters.len(), 2); + assert!(func.parameters[0].name.contains("self")); + assert_eq!(func.parameters[1].name, "value"); +} + +#[test] +fn test_extract_struct() { + let code = r#" + pub struct User { + pub id: u64, + name: String, + } + "#; + + let extractor = get_extractor("rust").unwrap(); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.types.len(), 1); + let type_def = &ast.types[0]; + assert_eq!(type_def.name, "User"); + assert_eq!(type_def.kind, TypeKind::Struct); + assert_eq!(type_def.fields.len(), 2); + + assert_eq!(type_def.fields[0].name, "id"); + assert_eq!(type_def.fields[0].field_type, Some("u64".to_string())); + assert_eq!(type_def.fields[0].visibility, Visibility::Public); + + assert_eq!(type_def.fields[1].name, "name"); + assert_eq!(type_def.fields[1].field_type, Some("String".to_string())); + assert_eq!(type_def.fields[1].visibility, Visibility::Private); +} + +#[test] +fn test_extract_enum() { + let code = r#" + pub enum Status { + Active, + Inactive, + } + "#; + + let extractor = get_extractor("rust").unwrap(); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.types.len(), 1); + let type_def = &ast.types[0]; + assert_eq!(type_def.name, "Status"); + assert_eq!(type_def.kind, TypeKind::Enum); +} + +#[test] +fn test_extract_trait() { + let code = r#" + pub trait Display { + fn 
display(&self) -> String; + } + "#; + + let extractor = get_extractor("rust").unwrap(); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.types.len(), 1); + let type_def = &ast.types[0]; + assert_eq!(type_def.name, "Display"); + assert_eq!(type_def.kind, TypeKind::Trait); +} + +#[test] +fn test_extract_use_statements() { + let code = r#" + use std::collections::HashMap; + use serde::{Serialize, Deserialize}; + use crate::models::*; + "#; + + let extractor = get_extractor("rust").unwrap(); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.imports.len(), 3); + assert!(ast.imports[0].module.contains("HashMap")); + assert!(ast.imports[1].module.contains("Serialize")); + assert!(ast.imports[2].is_wildcard); +} + +#[test] +fn test_extract_multiple_functions() { + let code = r#" + fn helper() {} + + pub async fn process(data: &str) -> Result<()> { + Ok(()) + } + + pub fn validate(input: String) -> bool { + true + } + "#; + + let extractor = get_extractor("rust").unwrap(); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 3); + assert_eq!(ast.functions[0].name, "helper"); + assert_eq!(ast.functions[1].name, "process"); + assert_eq!(ast.functions[2].name, "validate"); +} + +#[test] +fn test_extract_complex_code() { + let code = r#" + use std::fmt; + + pub struct Config { + pub port: u16, + host: String, + } + + impl Config { + pub fn new(port: u16) -> Self { + Self { + port, + host: "localhost".to_string(), + } + } + + pub async fn connect(&self) -> Result { + todo!() + } + } + + pub enum Status { + Connected, + Disconnected, + } + "#; + + let extractor = get_extractor("rust").unwrap(); + let ast = extractor.extract_all(code).unwrap(); + + // Should extract functions from impl block + assert!(ast.functions.len() >= 2); + + // Should extract struct and enum + assert!(ast.types.len() >= 2); + + // Should have imports + assert!(ast.imports.len() >= 1); + + // Should generate summary + 
assert!(!ast.structure_summary.is_empty()); +} + +#[test] +fn test_unsupported_language() { + let result = get_extractor("cobol"); + assert!(result.is_none()); +} + +#[test] +fn test_line_numbers() { + let code = r#" +fn first() {} + +fn second() {} + +fn third() {} +"#; + + let extractor = get_extractor("rust").unwrap(); + let ast = extractor.extract_all(code).unwrap(); + + assert_eq!(ast.functions.len(), 3); + assert_eq!(ast.functions[0].line_number, 2); + assert_eq!(ast.functions[1].line_number, 4); + assert_eq!(ast.functions[2].line_number, 6); +} diff --git a/crates/reposcout-cli/Cargo.toml b/crates/reposcout-cli/Cargo.toml index 79afa0a..9c62f8b 100644 --- a/crates/reposcout-cli/Cargo.toml +++ b/crates/reposcout-cli/Cargo.toml @@ -17,6 +17,7 @@ reposcout-cache = { path = "../reposcout-cache" } reposcout-tui = { path = "../reposcout-tui" } reposcout-api = { path = "../reposcout-api" } reposcout-semantic = { path = "../reposcout-semantic" } +reposcout-ast = { path = "../reposcout-ast" } clap = { workspace = true } tokio = { workspace = true } diff --git a/crates/reposcout-cli/src/main.rs b/crates/reposcout-cli/src/main.rs index 61e5b31..32200a0 100644 --- a/crates/reposcout-cli/src/main.rs +++ b/crates/reposcout-cli/src/main.rs @@ -69,6 +69,7 @@ enum Commands { /// Search for code within repositories Code { /// Code search query (e.g., "function auth", "class:User") + /// With --semantic: use natural language (e.g., "parse command line arguments") query: String, /// Number of results to show @@ -90,6 +91,30 @@ enum Commands { /// File extension filter (e.g., "rs", "py") #[arg(short = 'e', long)] extension: Option, + + /// Use semantic search (natural language queries) + #[arg(long)] + semantic: bool, + + /// Semantic weight for hybrid ranking (0.0-1.0, default: 0.7) + #[arg(long, default_value = "0.7")] + semantic_weight: f32, + + /// Enable AST-based code analysis + #[arg(long)] + ast: bool, + + /// AST score weight for hybrid ranking (0.0-1.0, default: 0.3) + 
#[arg(long, default_value = "0.3")] + ast_weight: f32, + + /// Filter by function name (requires --ast) + #[arg(long)] + function_name: Option, + + /// Filter by class/struct name (requires --ast) + #[arg(long)] + class_name: Option, }, /// Show repository details Show { @@ -346,6 +371,12 @@ async fn main() -> anyhow::Result<()> { repo, path, extension, + semantic, + semantic_weight, + ast, + ast_weight, + function_name, + class_name, }) => { search_code( &query, @@ -354,6 +385,12 @@ async fn main() -> anyhow::Result<()> { repo, path, extension, + semantic, + semantic_weight, + ast, + ast_weight, + function_name, + class_name, cli.github_token, cli.gitlab_token, cli.bitbucket_username, @@ -1109,6 +1146,12 @@ async fn search_code( repo: Option, path: Option, extension: Option, + semantic: bool, + semantic_weight: f32, + ast: bool, + ast_weight: f32, + function_name: Option, + class_name: Option, github_token: Option, gitlab_token: Option, bitbucket_username: Option, @@ -1117,8 +1160,64 @@ async fn search_code( use reposcout_api::{GitHubClient, GitLabClient}; use reposcout_core::models::{CodeMatch, CodeSearchResult, Platform}; + // Parse query for AST filters if AST is enabled + let (clean_query, ast_filters) = if ast { + use reposcout_ast::{parse_query, AstQueryFilters}; + + let parsed = parse_query(query); + println!("🌳 AST query parsing enabled"); + println!(" Original query: \"{}\"", query); + println!(" Cleaned query: \"{}\"", parsed.query); + if parsed.filters.has_filters() { + println!(" Detected filters:"); + if let Some(ref fn_name) = parsed.filters.function_name { + println!(" - Function: {}", fn_name); + } + if let Some(ref class_name) = parsed.filters.class_name { + println!(" - Class/Type: {}", class_name); + } + if parsed.filters.is_async == Some(true) { + println!(" - Async functions only"); + } + } + println!(); + + // Merge with CLI flags + let mut filters = parsed.filters; + let cli_filters = AstQueryFilters { + function_name, + class_name, + 
is_async: None, + is_public: None, + is_generic: None, + }; + filters.merge(&cli_filters); + + (parsed.query, filters) + } else { + use reposcout_ast::AstQueryFilters; + (query.to_string(), AstQueryFilters::default()) + }; + + // Determine search strategy + let (search_query, api_limit) = if semantic { + // For semantic search, extract keywords and fetch more results for re-ranking + use reposcout_semantic::extract_code_keywords; + + let keywords = extract_code_keywords(&clean_query); + println!("🔍 Semantic search mode"); + println!(" Natural language query: \"{}\"", &clean_query); + println!(" Extracted keywords: \"{}\"", keywords); + println!(); + + // Fetch more results (3-5x) for better re-ranking + (keywords, limit * 4) + } else { + (clean_query.clone(), limit) + }; + // Build enhanced query with filters - let mut search_query = query.to_string(); + let mut search_query = search_query; if let Some(lang) = language { search_query.push_str(&format!(" language:{}", lang)); @@ -1143,7 +1242,7 @@ async fn search_code( // Search GitHub if let Some(ref token) = github_token { let github_client = GitHubClient::new(Some(token.clone())); - match github_client.search_code(&search_query, limit as u32).await { + match github_client.search_code(&search_query, api_limit as u32).await { Ok(items) => { for item in items { // Convert GitHub results to our unified format @@ -1157,6 +1256,8 @@ async fn search_code( line_number: 1, context_before: vec![], context_after: vec![], + matched_functions: None, + matched_types: None, } }) .collect(); @@ -1168,6 +1269,8 @@ async fn search_code( line_number: 1, context_before: vec![], context_after: vec![], + matched_functions: None, + matched_types: None, }] } else { matches @@ -1177,11 +1280,13 @@ async fn search_code( platform: Platform::GitHub, repository: item.repository.full_name.clone(), file_path: item.path.clone(), - language: None, // Code search API doesn't return language + language: item.repository.language.clone(), file_url: 
item.html_url.clone(), - repository_url: item.repository.html_url.clone(), + repository_url: item.repository.html_url.clone() + .unwrap_or_else(|| format!("https://github.com/{}", item.repository.full_name)), matches, - repository_stars: 0, // Code search API doesn't return star count + repository_stars: item.repository.stargazers_count, + ast_metadata: None, }); } tracing::info!("Found {} results from GitHub", all_results.len()); @@ -1212,7 +1317,7 @@ async fn search_code( // Search GitLab if let Some(ref token) = gitlab_token { let gitlab_client = GitLabClient::new(Some(token.clone())); - match gitlab_client.search_code(query, limit as u32).await { + match gitlab_client.search_code(&search_query, api_limit as u32).await { Ok(items) => { // We need to fetch project details for each result // For now, create basic results @@ -1222,6 +1327,8 @@ async fn search_code( line_number: item.startline, context_before: vec![], context_after: vec![], + matched_functions: None, + matched_types: None, }]; all_results.push(CodeSearchResult { @@ -1232,6 +1339,7 @@ async fn search_code( file_url: format!("https://gitlab.com/projects/{}", item.project_id), repository_url: format!("https://gitlab.com/projects/{}", item.project_id), matches, + ast_metadata: None, repository_stars: 0, }); } @@ -1282,25 +1390,88 @@ async fn search_code( return Ok(()); } - // Sort by repository stars - all_results.sort_by(|a, b| b.repository_stars.cmp(&a.repository_stars)); + // Apply semantic/AST re-ranking if enabled + let final_results: Vec<(CodeSearchResult, Option)> = if semantic || ast { + use reposcout_semantic::CodeReranker; + + if ast { + println!("🌳 AST-enhanced semantic search enabled\n"); + } else { + println!("🧠 Re-ranking results using semantic similarity...\n"); + } + + // Initialize semantic re-ranker (with AST if enabled) + let reranker = if ast { + CodeReranker::with_ast("BAAI/bge-small-en-v1.5".to_string()) + } else { + CodeReranker::new("BAAI/bge-small-en-v1.5".to_string()) + }; + 
reranker.initialize().await?; + + // Re-rank results (with AST filtering and scoring if enabled) + if ast { + let results = reranker + .rerank_with_ast_filters( + query, + all_results, + &ast_filters, + limit, + semantic_weight, + ast_weight, + ) + .await?; + + // Convert 4-tuple to 2-tuple for display + results + .into_iter() + .map(|(result, semantic_score, _ast_score, _hybrid_score)| (result, Some(semantic_score))) + .collect() + } else { + let results = reranker + .rerank_hybrid(query, all_results, limit, semantic_weight) + .await?; + + results + .into_iter() + .map(|(result, semantic_score, _hybrid_score)| (result, Some(semantic_score))) + .collect() + } + } else { + // Sort by repository stars (traditional ranking) + all_results.sort_by(|a, b| b.repository_stars.cmp(&a.repository_stars)); + all_results + .into_iter() + .take(limit) + .map(|result| (result, None)) + .collect() + }; - println!("\n🔍 Found {} code matches:\n", all_results.len()); + println!("\n🔍 Found {} code matches:\n", final_results.len()); - for (i, result) in all_results.iter().take(limit).enumerate() { + for (i, (result, semantic_score)) in final_results.iter().enumerate() { println!("{}. 
{} ({})", i + 1, result.file_path, result.repository); - println!( - " Platform: {} | ⭐ {}", - result.platform, result.repository_stars - ); + + let mut info_parts = vec![ + format!("Platform: {}", result.platform), + format!("⭐ {}", result.repository_stars), + ]; + + if let Some(score) = semantic_score { + info_parts.push(format!("Similarity: {:.2}", score)); + } + + println!(" {}", info_parts.join(" | ")); + if let Some(lang) = &result.language { println!(" Language: {}", lang); } // Show first match snippet if let Some(first_match) = result.matches.first() { - let snippet = if first_match.content.len() > 150 { - format!("{}...", &first_match.content[..150]) + let snippet = if first_match.content.chars().count() > 150 { + // Use char-aware truncation to avoid splitting multi-byte characters + let truncated: String = first_match.content.chars().take(150).collect(); + format!("{}...", truncated) } else { first_match.content.clone() }; @@ -1400,8 +1571,9 @@ async fn show_trending( println!("{}. {} ({})", i + 1, repo.full_name, repo.platform); if let Some(desc) = &repo.description { - let short_desc = if desc.len() > 100 { - format!("{}...", &desc[..100]) + let short_desc = if desc.chars().count() > 100 { + let truncated: String = desc.chars().take(100).collect(); + format!("{}...", truncated) } else { desc.clone() }; @@ -1562,16 +1734,8 @@ async fn handle_semantic_search( let keyword_results = keyword_engine.search(query).await?; - // Combine with semantic search - let keyword_pairs: Vec<(reposcout_core::models::Repository, f32)> = keyword_results - .into_iter() - .enumerate() - .map(|(i, repo)| { - // Assign decreasing scores based on position - let score = 1.0 - (i as f32 / 100.0).min(0.9); - (repo, score) - }) - .collect(); + // Score keyword results using BM25 + let keyword_pairs = reposcout_semantic::score_keyword_results(keyword_results, query); engine.hybrid_search(query, keyword_pairs, limit).await? 
} else { diff --git a/crates/reposcout-core/src/models.rs b/crates/reposcout-core/src/models.rs index 87ba9f1..539a0f0 100644 --- a/crates/reposcout-core/src/models.rs +++ b/crates/reposcout-core/src/models.rs @@ -130,6 +130,9 @@ pub struct CodeSearchResult { pub matches: Vec, /// Repository stars (for sorting) pub repository_stars: u32, + /// AST metadata (when AST analysis is enabled) + #[serde(skip_serializing_if = "Option::is_none")] + pub ast_metadata: Option, } /// A code match with line numbers and context @@ -142,4 +145,98 @@ pub struct CodeMatch { /// Optional: surrounding context lines pub context_before: Vec, pub context_after: Vec, + /// Matched functions (when AST analysis is enabled) + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_functions: Option>, + /// Matched types (when AST analysis is enabled) + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_types: Option>, +} + +/// AST metadata extracted from code +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AstMetadata { + /// Detected programming language + pub language: String, + /// Extracted functions/methods + pub functions: Vec, + /// Extracted classes/structs/types + pub types: Vec, + /// Import statements + pub imports: Vec, + /// Overall code structure summary + pub structure_summary: String, + /// Whether AST parsing was successful + pub parse_success: bool, + /// Parse error message if failed + pub parse_error: Option, +} + +/// Function or method signature +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionSignature { + pub name: String, + pub parameters: Vec, + pub return_type: Option, + pub visibility: Visibility, + pub is_async: bool, + pub is_generic: bool, + pub doc_comment: Option, + pub line_number: usize, +} + +/// Function parameter +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Parameter { + pub name: String, + pub param_type: Option, + pub is_optional: bool, +} + +/// Visibility modifier +#[derive(Debug, 
Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub enum Visibility { + Public, + Private, + Protected, + Internal, +} + +/// Type definition (class, struct, interface, etc.) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TypeDefinition { + pub name: String, + pub kind: TypeKind, + pub fields: Vec, + pub methods: Vec, + pub implements: Vec, + pub doc_comment: Option, + pub line_number: usize, +} + +/// Type kind +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub enum TypeKind { + Class, + Struct, + Interface, + Trait, + Enum, + Type, +} + +/// Struct/class field +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Field { + pub name: String, + pub field_type: Option, + pub visibility: Visibility, +} + +/// Import statement +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImportStatement { + pub module: String, + pub items: Vec, + pub is_wildcard: bool, } diff --git a/crates/reposcout-semantic/Cargo.toml b/crates/reposcout-semantic/Cargo.toml index 6c23331..eefc086 100644 --- a/crates/reposcout-semantic/Cargo.toml +++ b/crates/reposcout-semantic/Cargo.toml @@ -11,6 +11,7 @@ homepage.workspace = true [dependencies] # Core dependencies reposcout-core = { path = "../reposcout-core" } +reposcout-ast = { path = "../reposcout-ast" } # Embedding and ML # Using version 4 to avoid ort compatibility issues diff --git a/crates/reposcout-semantic/src/ast_analyzer.rs b/crates/reposcout-semantic/src/ast_analyzer.rs new file mode 100644 index 0000000..b0a3d3c --- /dev/null +++ b/crates/reposcout-semantic/src/ast_analyzer.rs @@ -0,0 +1,156 @@ +use crate::error::Result; +use reposcout_ast::{extractors, parser::ParserCache}; +use reposcout_core::models::CodeSearchResult; +use tracing::{debug, warn}; + +/// AST analyzer for enriching code search results +pub struct AstAnalyzer { + enabled: bool, +} + +impl AstAnalyzer { + /// Create a new AST analyzer + pub fn new(enabled: bool) -> Self { + Self { enabled } + } + + /// 
Enrich a single code search result with AST metadata + pub fn enrich_result(&self, result: &mut CodeSearchResult) -> Result<()> { + if !self.enabled { + return Ok(()); + } + + let language = match &result.language { + Some(lang) => lang.clone(), + None => { + // Try to detect from file path + if let Some(detected) = ParserCache::detect_language(&result.file_path) { + detected + } else { + debug!("Cannot detect language for {}", result.file_path); + return Ok(()); + } + } + }; + + // Check if language is supported + if !ParserCache::is_language_supported(&language) { + debug!("Language not supported for AST: {}", language); + return Ok(()); + } + + // Get extractor for this language + let extractor = match extractors::get_extractor(&language) { + Some(ext) => ext, + None => { + warn!("No AST extractor for language: {}", language); + return Ok(()); + } + }; + + // Extract from first match's content + if let Some(first_match) = result.matches.first() { + match extractor.extract_all(&first_match.content) { + Ok(ast_metadata) => { + debug!( + "Extracted AST for {}: {} functions, {} types", + result.file_path, + ast_metadata.functions.len(), + ast_metadata.types.len() + ); + result.ast_metadata = Some(ast_metadata); + } + Err(e) => { + warn!("AST extraction failed for {}: {}", result.file_path, e); + // Set failed metadata + result.ast_metadata = Some(reposcout_core::models::AstMetadata { + language: language.clone(), + functions: Vec::new(), + types: Vec::new(), + imports: Vec::new(), + structure_summary: String::new(), + parse_success: false, + parse_error: Some(e.to_string()), + }); + } + } + } + + Ok(()) + } + + /// Enrich multiple results in batch + pub fn enrich_results(&self, results: &mut [CodeSearchResult]) -> Result<()> { + for result in results.iter_mut() { + // Continue even if individual enrichment fails + let _ = self.enrich_result(result); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use reposcout_core::models::{CodeMatch, Platform}; 
+ + #[test] + fn test_enrich_rust_code() { + let analyzer = AstAnalyzer::new(true); + + let mut result = CodeSearchResult { + platform: Platform::GitHub, + repository: "test/repo".to_string(), + file_path: "test.rs".to_string(), + language: Some("rust".to_string()), + file_url: "https://example.com".to_string(), + repository_url: "https://example.com".to_string(), + repository_stars: 100, + ast_metadata: None, + matches: vec![CodeMatch { + content: "pub fn hello() -> String { String::from(\"hello\") }".to_string(), + line_number: 1, + context_before: Vec::new(), + context_after: Vec::new(), + matched_functions: None, + matched_types: None, + }], + }; + + analyzer.enrich_result(&mut result).unwrap(); + + assert!(result.ast_metadata.is_some()); + let ast = result.ast_metadata.unwrap(); + assert!(ast.parse_success); + assert_eq!(ast.functions.len(), 1); + assert_eq!(ast.functions[0].name, "hello"); + } + + #[test] + fn test_disabled_analyzer() { + let analyzer = AstAnalyzer::new(false); + + let mut result = CodeSearchResult { + platform: Platform::GitHub, + repository: "test/repo".to_string(), + file_path: "test.rs".to_string(), + language: Some("rust".to_string()), + file_url: "https://example.com".to_string(), + repository_url: "https://example.com".to_string(), + repository_stars: 100, + ast_metadata: None, + matches: vec![CodeMatch { + content: "pub fn hello() {}".to_string(), + line_number: 1, + context_before: Vec::new(), + context_after: Vec::new(), + matched_functions: None, + matched_types: None, + }], + }; + + analyzer.enrich_result(&mut result).unwrap(); + + assert!(result.ast_metadata.is_none()); + } +} diff --git a/crates/reposcout-semantic/src/bm25.rs b/crates/reposcout-semantic/src/bm25.rs new file mode 100644 index 0000000..7fd5ced --- /dev/null +++ b/crates/reposcout-semantic/src/bm25.rs @@ -0,0 +1,249 @@ +//! BM25 scoring for keyword-based retrieval +//! +//! Implements the Okapi BM25 ranking function for better keyword matching. 
+ +use reposcout_core::models::Repository; +use std::collections::HashMap; + +/// BM25 scoring parameters +const K1: f32 = 1.2; // Term frequency saturation +const B: f32 = 0.75; // Document length normalization + +/// BM25 scorer for repositories +pub struct BM25Scorer { + /// Document frequencies for each term + doc_frequencies: HashMap, + /// Total number of documents + total_docs: usize, + /// Average document length + avg_doc_len: f32, +} + +impl BM25Scorer { + /// Create a new BM25 scorer from a collection of repositories + pub fn new(repos: &[Repository]) -> Self { + let mut doc_frequencies: HashMap = HashMap::new(); + let mut total_length = 0usize; + + for repo in repos { + let text = repository_to_text(repo); + let tokens = tokenize(&text); + total_length += tokens.len(); + + // Count unique terms in this document + let unique_terms: std::collections::HashSet<_> = tokens.into_iter().collect(); + for term in unique_terms { + *doc_frequencies.entry(term).or_insert(0) += 1; + } + } + + let total_docs = repos.len(); + let avg_doc_len = if total_docs > 0 { + total_length as f32 / total_docs as f32 + } else { + 1.0 + }; + + Self { + doc_frequencies, + total_docs, + avg_doc_len, + } + } + + /// Score a single repository against a query + pub fn score(&self, repo: &Repository, query: &str) -> f32 { + let doc_text = repository_to_text(repo); + let doc_tokens = tokenize(&doc_text); + let query_tokens = tokenize(query); + + if doc_tokens.is_empty() || query_tokens.is_empty() { + return 0.0; + } + + // Count term frequencies in document + let mut term_freqs: HashMap = HashMap::new(); + for token in &doc_tokens { + *term_freqs.entry(token.clone()).or_insert(0) += 1; + } + + let doc_len = doc_tokens.len() as f32; + let mut score = 0.0; + + for term in query_tokens { + let freq = *term_freqs.get(&term).unwrap_or(&0) as f32; + if freq == 0.0 { + continue; + } + + // Calculate IDF + let n = *self.doc_frequencies.get(&term).unwrap_or(&0) as f32; + let idf = 
((self.total_docs as f32 - n + 0.5) / (n + 0.5) + 1.0).ln(); + + // Calculate BM25 term score + let numerator = freq * (K1 + 1.0); + let denominator = freq + K1 * (1.0 - B + B * doc_len / self.avg_doc_len); + + score += idf * (numerator / denominator); + } + + score + } + + /// Score multiple repositories and return sorted results + pub fn score_all(&self, repos: &[Repository], query: &str) -> Vec<(Repository, f32)> { + let mut scored: Vec<(Repository, f32)> = repos + .iter() + .map(|repo| { + let score = self.score(repo, query); + (repo.clone(), score) + }) + .collect(); + + // Sort by score descending + scored.sort_by(|a, b| { + b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) + }); + + scored + } +} + +/// Convert a repository to searchable text +fn repository_to_text(repo: &Repository) -> String { + let mut parts = Vec::new(); + + // Name (weighted by repetition for importance) + let name = repo.full_name.split('/').last().unwrap_or(&repo.full_name); + parts.push(name.to_string()); + parts.push(name.to_string()); // Double weight for name + + // Description + if let Some(desc) = &repo.description { + parts.push(desc.clone()); + } + + // Language + if let Some(lang) = &repo.language { + parts.push(lang.clone()); + } + + // Topics (important for search) + for topic in &repo.topics { + parts.push(topic.clone()); + } + + parts.join(" ") +} + +/// Tokenize text into lowercase terms +fn tokenize(text: &str) -> Vec { + text.to_lowercase() + .split(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty() && s.len() > 1) // Skip single chars + .map(|s| s.to_string()) + .collect() +} + +/// Score keyword results using BM25 +/// +/// Takes pre-fetched keyword results and computes proper BM25 scores +pub fn score_keyword_results(repos: Vec, query: &str) -> Vec<(Repository, f32)> { + if repos.is_empty() { + return Vec::new(); + } + + let scorer = BM25Scorer::new(&repos); + scorer.score_all(&repos, query) +} + +#[cfg(test)] +mod tests { + use super::*; + use 
reposcout_core::models::Platform; + + fn create_test_repo(name: &str, description: &str, topics: Vec<&str>) -> Repository { + Repository { + platform: Platform::GitHub, + full_name: format!("user/{}", name), + description: Some(description.to_string()), + url: format!("https://github.com/user/{}", name), + homepage_url: None, + stars: 100, + forks: 10, + watchers: 50, + open_issues: 5, + language: Some("Rust".to_string()), + topics: topics.into_iter().map(|s| s.to_string()).collect(), + license: Some("MIT".to_string()), + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + pushed_at: chrono::Utc::now(), + size: 1024, + default_branch: "main".to_string(), + is_archived: false, + is_private: false, + health: None, + } + } + + #[test] + fn test_bm25_basic_scoring() { + let repos = vec![ + create_test_repo("logging-lib", "A logging library for applications", vec!["logging", "library"]), + create_test_repo("web-server", "A web server framework", vec!["web", "server"]), + create_test_repo("log-parser", "Parse log files efficiently", vec!["log", "parser"]), + ]; + + let scorer = BM25Scorer::new(&repos); + + // Query for "logging" should rank logging-lib highest + let score1 = scorer.score(&repos[0], "logging"); + let score2 = scorer.score(&repos[1], "logging"); + let score3 = scorer.score(&repos[2], "logging"); + + assert!(score1 > score2, "logging-lib should score higher than web-server for 'logging'"); + assert!(score1 > score3, "logging-lib should score higher than log-parser for 'logging'"); + } + + #[test] + fn test_bm25_multi_term_query() { + let repos = vec![ + create_test_repo("web-framework", "A web framework for building APIs", vec!["web", "api"]), + create_test_repo("api-client", "Client library for APIs", vec!["api", "client"]), + ]; + + let scorer = BM25Scorer::new(&repos); + + // Query for "web api" should consider both terms + let results = scorer.score_all(&repos, "web api"); + + // web-framework should rank first (has both web and api) + 
assert_eq!(results[0].0.full_name, "user/web-framework"); + } + + #[test] + fn test_score_keyword_results() { + let repos = vec![ + create_test_repo("rust-cli", "Command line tool in Rust", vec!["rust", "cli"]), + create_test_repo("python-cli", "Command line tool in Python", vec!["python", "cli"]), + ]; + + let results = score_keyword_results(repos, "rust cli"); + + assert!(!results.is_empty()); + assert_eq!(results[0].0.full_name, "user/rust-cli"); + assert!(results[0].1 > results[1].1); + } + + #[test] + fn test_tokenize() { + let tokens = tokenize("Hello, World! This is a test-string"); + assert!(tokens.contains(&"hello".to_string())); + assert!(tokens.contains(&"world".to_string())); + assert!(tokens.contains(&"test".to_string())); + assert!(tokens.contains(&"string".to_string())); + // Single char 'a' should be filtered out + assert!(!tokens.contains(&"a".to_string())); + } +} diff --git a/crates/reposcout-semantic/src/code_reranker.rs b/crates/reposcout-semantic/src/code_reranker.rs new file mode 100644 index 0000000..36cb3b1 --- /dev/null +++ b/crates/reposcout-semantic/src/code_reranker.rs @@ -0,0 +1,424 @@ +use crate::embeddings::{cosine_similarity, EmbeddingGenerator}; +use crate::error::Result; +use reposcout_core::models::CodeSearchResult; +use tracing::{debug, info}; + +/// Re-rank code search results using semantic similarity +pub struct CodeReranker { + embedder: EmbeddingGenerator, + ast_analyzer: Option, +} + +impl CodeReranker { + /// Create a new code re-ranker + pub fn new(model_name: String) -> Self { + Self { + embedder: EmbeddingGenerator::new(model_name), + ast_analyzer: None, + } + } + + /// Create a new code re-ranker with AST analysis enabled + pub fn with_ast(model_name: String) -> Self { + Self { + embedder: EmbeddingGenerator::new(model_name), + ast_analyzer: Some(crate::ast_analyzer::AstAnalyzer::new(true)), + } + } + + /// Initialize the embedding model + pub async fn initialize(&self) -> Result<()> { + 
self.embedder.initialize().await + } + + /// Re-rank code search results based on semantic similarity to query + pub async fn rerank( + &self, + query: &str, + results: Vec, + limit: usize, + ) -> Result> { + if results.is_empty() { + return Ok(Vec::new()); + } + + info!( + "Re-ranking {} code search results for query: {}", + results.len(), + query + ); + + // Generate query embedding + let query_embedding = self.embedder.embed_query(query).await?; + + // Generate embeddings for all code snippets in batch + let snippets: Vec<(&str, Option<&str>, &str, Option<&reposcout_core::models::AstMetadata>)> = results + .iter() + .map(|result| { + let code = if let Some(first_match) = result.matches.first() { + first_match.content.as_str() + } else { + "" + }; + ( + code, + result.language.as_deref(), + result.file_path.as_str(), + result.ast_metadata.as_ref(), + ) + }) + .collect(); + + debug!("Generating embeddings for {} code snippets", snippets.len()); + let code_embeddings = self.embedder.embed_code_snippets(snippets).await?; + + // Calculate similarity scores + let mut scored_results: Vec<(CodeSearchResult, f32)> = results + .into_iter() + .zip(code_embeddings.iter()) + .map(|(result, code_embedding)| { + let similarity = cosine_similarity(&query_embedding, code_embedding); + (result, similarity) + }) + .collect(); + + // Sort by similarity (descending) + scored_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + + // Limit results + scored_results.truncate(limit); + + info!( + "Re-ranking complete. 
Top result has similarity: {:.3}", + scored_results + .first() + .map(|(_, score)| *score) + .unwrap_or(0.0) + ); + + Ok(scored_results) + } + + /// Re-rank with hybrid scoring (combining semantic similarity and keyword match) + pub async fn rerank_hybrid( + &self, + query: &str, + results: Vec, + limit: usize, + semantic_weight: f32, + ) -> Result> { + if results.is_empty() { + return Ok(Vec::new()); + } + + info!( + "Hybrid re-ranking {} code search results (semantic_weight: {:.2})", + results.len(), + semantic_weight + ); + + // Get semantic scores + let semantic_results = self.rerank(query, results, usize::MAX).await?; + + // Calculate keyword scores (simple word overlap) + let query_words: std::collections::HashSet = query + .to_lowercase() + .split_whitespace() + .map(|s| s.to_string()) + .collect(); + + let mut hybrid_results: Vec<(CodeSearchResult, f32, f32)> = semantic_results + .into_iter() + .map(|(result, semantic_score)| { + // Calculate keyword score based on word overlap + let code_text = result + .matches + .iter() + .map(|m| m.content.as_str()) + .collect::>() + .join(" "); + + let code_words: std::collections::HashSet = code_text + .to_lowercase() + .split_whitespace() + .map(|s| s.to_string()) + .collect(); + + let overlap = query_words.intersection(&code_words).count(); + let keyword_score = if query_words.is_empty() { + 0.0 + } else { + overlap as f32 / query_words.len() as f32 + }; + + // Combined hybrid score + let hybrid_score = + (semantic_score * semantic_weight) + (keyword_score * (1.0 - semantic_weight)); + + (result, semantic_score, hybrid_score) + }) + .collect(); + + // Sort by hybrid score (descending) + hybrid_results.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)); + + // Limit results + hybrid_results.truncate(limit); + + info!( + "Hybrid re-ranking complete. 
Top result has hybrid score: {:.3}", + hybrid_results + .first() + .map(|(_, _, score)| *score) + .unwrap_or(0.0) + ); + + Ok(hybrid_results) + } + + /// Re-rank with AST enrichment (if enabled) + /// Enriches results with AST metadata before performing hybrid reranking + pub async fn rerank_with_ast( + &self, + query: &str, + mut results: Vec, + limit: usize, + semantic_weight: f32, + ) -> Result> { + // Enrich results with AST if analyzer is available + if let Some(ref analyzer) = self.ast_analyzer { + info!("Enriching {} results with AST metadata", results.len()); + analyzer.enrich_results(&mut results)?; + + let enriched_count = results.iter().filter(|r| r.ast_metadata.is_some()).count(); + info!("{} results successfully enriched with AST", enriched_count); + } + + // Perform hybrid reranking (now with AST-enhanced embeddings) + self.rerank_hybrid(query, results, limit, semantic_weight).await + } + + /// Re-rank with AST filtering and scoring + /// Applies AST filters and incorporates AST similarity scores into ranking + pub async fn rerank_with_ast_filters( + &self, + query: &str, + mut results: Vec, + filters: &reposcout_ast::AstQueryFilters, + limit: usize, + semantic_weight: f32, + ast_weight: f32, + ) -> Result> { + if results.is_empty() { + return Ok(Vec::new()); + } + + // Enrich results with AST if analyzer is available + if let Some(ref analyzer) = self.ast_analyzer { + info!("Enriching {} results with AST metadata", results.len()); + analyzer.enrich_results(&mut results)?; + + let enriched_count = results.iter().filter(|r| r.ast_metadata.is_some()).count(); + let successful_count = results.iter().filter(|r| { + r.ast_metadata.as_ref().map(|ast| ast.parse_success).unwrap_or(false) + }).count(); + info!("{} results enriched with AST ({} successful, {} failed)", + enriched_count, successful_count, enriched_count - successful_count); + } else { + info!("No AST analyzer available - AST scoring will use neutral scores"); + } + + // Apply hard filters if any 
are set + if filters.has_filters() { + info!("Applying AST filters to {} results", results.len()); + results.retain(|result| { + if let Some(ref ast) = result.ast_metadata { + // Apply hard filters using scorer module + if let Some(ref fn_name) = filters.function_name { + if !ast.functions.iter().any(|f| { + reposcout_ast::scorer::string_similarity(&f.name, fn_name) > 0.5 + }) { + return false; + } + } + + if let Some(ref class_name) = filters.class_name { + if !ast.types.iter().any(|t| { + reposcout_ast::scorer::string_similarity(&t.name, class_name) > 0.5 + }) { + return false; + } + } + + if let Some(is_async) = filters.is_async { + let has_async = ast.functions.iter().any(|f| f.is_async); + if has_async != is_async { + return false; + } + } + + true + } else { + // Keep results without AST metadata - they'll get neutral AST scores + // Don't filter them out just because AST parsing failed + true + } + }); + info!("After filtering: {} results remain", results.len()); + } + + if results.is_empty() { + return Ok(Vec::new()); + } + + info!( + "AST-filtered hybrid re-ranking {} results (semantic: {:.2}, ast: {:.2})", + results.len(), + semantic_weight, + ast_weight + ); + + // Get semantic scores + let semantic_results = self.rerank(query, results, usize::MAX).await?; + + // Calculate keyword scores and AST scores + let query_words: std::collections::HashSet = query + .to_lowercase() + .split_whitespace() + .map(|s| s.to_string()) + .collect(); + + let keyword_weight = 1.0 - semantic_weight - ast_weight; + + let mut scored_results: Vec<(CodeSearchResult, f32, f32, f32)> = semantic_results + .into_iter() + .map(|(result, semantic_score)| { + // Calculate keyword score + let code_text = result + .matches + .iter() + .map(|m| m.content.as_str()) + .collect::>() + .join(" "); + + let code_words: std::collections::HashSet = code_text + .to_lowercase() + .split_whitespace() + .map(|s| s.to_string()) + .collect(); + + let overlap = 
query_words.intersection(&code_words).count(); + let keyword_score = if query_words.is_empty() { + 0.0 + } else { + overlap as f32 / query_words.len() as f32 + }; + + // Calculate AST score + // If AST metadata is missing or failed to parse, use neutral score (0.5) + // This prevents results from being unfairly penalized when AST extraction fails + let ast_score = if let Some(ref ast) = result.ast_metadata { + if ast.parse_success { + reposcout_ast::scorer::score_ast_match(ast, filters, query) + } else { + // AST parsing failed - use neutral score + 0.5 + } + } else { + // No AST metadata - use neutral score + 0.5 + }; + + // Combined hybrid score with AST + let hybrid_score = (semantic_score * semantic_weight) + + (keyword_score * keyword_weight) + + (ast_score * ast_weight); + + (result, semantic_score, ast_score, hybrid_score) + }) + .collect(); + + // Sort by hybrid score (descending) + scored_results.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal)); + + // Limit results + scored_results.truncate(limit); + + info!( + "AST-filtered hybrid re-ranking complete. 
Top result has hybrid score: {:.3}", + scored_results + .first() + .map(|(_, _, _, score)| *score) + .unwrap_or(0.0) + ); + + Ok(scored_results) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use reposcout_core::models::{CodeMatch, Platform}; + + fn create_test_result(code: &str, language: &str) -> CodeSearchResult { + CodeSearchResult { + platform: Platform::GitHub, + repository: "test/repo".to_string(), + file_path: "src/main.rs".to_string(), + language: Some(language.to_string()), + file_url: "https://github.com/test/repo".to_string(), + repository_url: "https://github.com/test/repo".to_string(), + matches: vec![CodeMatch { + content: code.to_string(), + line_number: 1, + context_before: vec![], + context_after: vec![], + matched_functions: None, + matched_types: None, + }], + repository_stars: 100, + ast_metadata: None, + } + } + + #[tokio::test] + async fn test_reranker_initialization() { + let reranker = CodeReranker::new("BAAI/bge-small-en-v1.5".to_string()); + assert!(reranker.initialize().await.is_ok()); + } + + #[tokio::test] + async fn test_empty_results() { + let reranker = CodeReranker::new("BAAI/bge-small-en-v1.5".to_string()); + reranker.initialize().await.unwrap(); + + let results = reranker + .rerank("test query", vec![], 10) + .await + .unwrap(); + + assert!(results.is_empty()); + } + + #[tokio::test] + async fn test_rerank_code_snippets() { + let reranker = CodeReranker::new("BAAI/bge-small-en-v1.5".to_string()); + reranker.initialize().await.unwrap(); + + let results = vec![ + create_test_result("fn add(a: i32, b: i32) -> i32 { a + b }", "Rust"), + create_test_result("fn parse_args() -> Vec { vec![] }", "Rust"), + create_test_result("fn multiply(x: i32, y: i32) -> i32 { x * y }", "Rust"), + ]; + + let reranked = reranker + .rerank("function that parses command line arguments", results, 10) + .await + .unwrap(); + + // The parse_args function should rank highest + assert!(!reranked.is_empty()); + 
assert!(reranked[0].0.matches[0].content.contains("parse_args")); + } +} diff --git a/crates/reposcout-semantic/src/embeddings.rs b/crates/reposcout-semantic/src/embeddings.rs index 7814e64..8ccc8ab 100644 --- a/crates/reposcout-semantic/src/embeddings.rs +++ b/crates/reposcout-semantic/src/embeddings.rs @@ -210,6 +210,53 @@ impl EmbeddingGenerator { // Generate embedding self.embed_text(&processed_query).await } + + /// Generate embedding for a code snippet + pub async fn embed_code_snippet( + &self, + code: &str, + language: Option<&str>, + file_path: &str, + ast_metadata: Option<&reposcout_core::models::AstMetadata>, + ) -> Result> { + // Preprocess code WITH AST context + let processed_code = crate::preprocessing::preprocess_code_snippet( + code, + language, + file_path, + ast_metadata, + ); + + if processed_code.is_empty() { + return Err(SemanticError::PreprocessingError( + "Empty code after preprocessing".to_string(), + )); + } + + // Generate embedding + self.embed_text(&processed_code).await + } + + /// Generate embeddings for multiple code snippets in batch + pub async fn embed_code_snippets( + &self, + snippets: Vec<(&str, Option<&str>, &str, Option<&reposcout_core::models::AstMetadata>)>, // (code, language, file_path, ast_metadata) + ) -> Result>> { + if snippets.is_empty() { + return Ok(Vec::new()); + } + + // Preprocess all snippets + let processed: Vec = snippets + .iter() + .map(|(code, lang, path, ast)| { + crate::preprocessing::preprocess_code_snippet(code, *lang, path, *ast) + }) + .collect(); + + // Generate embeddings in batch + self.embed_batch(processed).await + } } /// Calculate cosine similarity between two vectors diff --git a/crates/reposcout-semantic/src/lib.rs b/crates/reposcout-semantic/src/lib.rs index e9c7784..6ade062 100644 --- a/crates/reposcout-semantic/src/lib.rs +++ b/crates/reposcout-semantic/src/lib.rs @@ -4,6 +4,9 @@ // and vector similarity search. 
It enables natural language queries and // finding repositories by use case rather than just keywords. +pub mod ast_analyzer; +pub mod bm25; +pub mod code_reranker; pub mod embeddings; pub mod error; pub mod index; @@ -12,11 +15,14 @@ pub mod preprocessing; pub mod search; // Re-export main types +pub use ast_analyzer::AstAnalyzer; +pub use bm25::{score_keyword_results, BM25Scorer}; +pub use code_reranker::CodeReranker; pub use embeddings::{cosine_similarity, EmbeddingGenerator}; pub use error::{Result, SemanticError}; pub use index::VectorIndex; pub use models::{EmbeddingEntry, IndexStats, SemanticConfig, SemanticSearchResult}; -pub use preprocessing::{preprocess_query, preprocess_repository}; +pub use preprocessing::{extract_code_keywords, preprocess_code_snippet, preprocess_query, preprocess_repository}; pub use search::SemanticSearchEngine; #[cfg(test)] diff --git a/crates/reposcout-semantic/src/models.rs b/crates/reposcout-semantic/src/models.rs index 4f03959..51d4e49 100644 --- a/crates/reposcout-semantic/src/models.rs +++ b/crates/reposcout-semantic/src/models.rs @@ -206,7 +206,7 @@ fn default_enabled() -> bool { } fn default_model() -> String { - "sentence-transformers/all-MiniLM-L6-v2".to_string() + "BAAI/bge-small-en-v1.5".to_string() } fn default_auto_build() -> bool { @@ -218,7 +218,7 @@ fn default_semantic_weight() -> f32 { } fn default_min_similarity() -> f32 { - 0.3 + 0.5 } fn default_max_results() -> usize { diff --git a/crates/reposcout-semantic/src/preprocessing.rs b/crates/reposcout-semantic/src/preprocessing.rs index 5154f25..635bfeb 100644 --- a/crates/reposcout-semantic/src/preprocessing.rs +++ b/crates/reposcout-semantic/src/preprocessing.rs @@ -51,6 +51,101 @@ pub fn preprocess_query(query: &str) -> String { truncate_to_tokens(&cleaned, MAX_TOKENS) } +/// Extract keywords from natural language query for code search +/// Converts "function that parses command line arguments" → "parse command line arguments function" +pub fn 
extract_code_keywords(query: &str) -> String { + // Common stop words to remove for code search + let stop_words = [ + "a", "an", "the", "that", "which", "who", "what", "where", "when", "how", + "to", "for", "of", "in", "on", "at", "from", "with", "by", "is", "are", + "was", "were", "be", "been", "being", "have", "has", "had", "do", "does", + "did", "will", "would", "should", "could", "may", "might", "must", "can", + "this", "these", "those", "i", "you", "he", "she", "it", "we", "they", + ]; + + let cleaned = clean_text(query); + let words: Vec<&str> = cleaned + .split_whitespace() + .filter(|word| !stop_words.contains(word)) + .collect(); + + words.join(" ") +} + +/// Preprocess code snippet for embedding +/// Combines documentation, signature, and code body +/// NOW ENHANCED: Includes AST information when available +pub fn preprocess_code_snippet( + code: &str, + language: Option<&str>, + file_path: &str, + ast_metadata: Option<&reposcout_core::models::AstMetadata>, +) -> String { + let mut parts = Vec::new(); + + // Add language context + if let Some(lang) = language { + parts.push(lang.to_lowercase()); + } + + // Extract filename (can provide context about purpose) + if let Some(filename) = std::path::Path::new(file_path).file_name() { + if let Some(name) = filename.to_str() { + parts.push(name.replace('_', " ").replace('-', " ")); + } + } + + // NEW: Add AST structure summary + if let Some(ast) = ast_metadata { + if ast.parse_success { + if !ast.structure_summary.is_empty() { + parts.push(ast.structure_summary.clone()); + } + + // Add function names for semantic context + for func in &ast.functions { + parts.push(format!("function {}", func.name)); + if func.is_async { + parts.push("async".to_string()); + } + } + + // Add type names + for type_def in &ast.types { + parts.push(format!("{:?} {}", type_def.kind, type_def.name).to_lowercase()); + } + } + } + + // Clean and add code + let cleaned_code = clean_code(code); + parts.push(cleaned_code); + + // Combine 
all parts + let combined = parts.join(" "); + truncate_to_tokens(&combined, MAX_TOKENS) +} + +/// Clean code by removing excessive whitespace while preserving structure +fn clean_code(code: &str) -> String { + // Remove excessive blank lines + let lines: Vec<&str> = code + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); + + // Normalize whitespace on each line + let normalized: Vec = lines + .iter() + .map(|line| { + let whitespace = Regex::new(r"\s+").unwrap(); + whitespace.replace_all(line.trim(), " ").to_string() + }) + .collect(); + + normalized.join(" ") +} + /// Clean text by removing special characters and normalizing whitespace fn clean_text(text: &str) -> String { // Remove URLs diff --git a/crates/reposcout-semantic/tests/ast_integration_test.rs b/crates/reposcout-semantic/tests/ast_integration_test.rs new file mode 100644 index 0000000..3d3b4aa --- /dev/null +++ b/crates/reposcout-semantic/tests/ast_integration_test.rs @@ -0,0 +1,138 @@ +use reposcout_core::models::{CodeMatch, CodeSearchResult, Platform}; +use reposcout_semantic::{AstAnalyzer, CodeReranker}; + +#[tokio::test] +async fn test_ast_enrichment_flow() { + // Create a test result with Rust code + let mut result = CodeSearchResult { + platform: Platform::GitHub, + repository: "test/repo".to_string(), + file_path: "src/main.rs".to_string(), + language: Some("rust".to_string()), + file_url: "https://example.com".to_string(), + repository_url: "https://example.com".to_string(), + repository_stars: 100, + ast_metadata: None, + matches: vec![CodeMatch { + content: r#" + pub async fn parse_args(args: Vec) -> Result { + let mut config = Config::default(); + for arg in args { + config.add(arg); + } + Ok(config) + } + + pub struct Config { + pub values: Vec, + } + "# + .to_string(), + line_number: 1, + context_before: Vec::new(), + context_after: Vec::new(), + matched_functions: None, + matched_types: None, + }], + }; + + // Test AST enrichment + let analyzer = AstAnalyzer::new(true); + 
analyzer.enrich_result(&mut result).unwrap(); + + // Verify AST metadata was extracted + assert!(result.ast_metadata.is_some()); + let ast = result.ast_metadata.as_ref().unwrap(); + + assert!(ast.parse_success, "AST parsing should succeed"); + assert_eq!(ast.language, "rust"); + + // Check functions were extracted + assert_eq!(ast.functions.len(), 1, "Should extract 1 function"); + assert_eq!(ast.functions[0].name, "parse_args"); + assert!(ast.functions[0].is_async, "Should detect async"); + assert_eq!(ast.functions[0].parameters.len(), 1); + + // Check types were extracted + assert_eq!(ast.types.len(), 1, "Should extract 1 type"); + assert_eq!(ast.types[0].name, "Config"); + + println!("✅ AST Enrichment Test Passed!"); + println!(" - Extracted {} functions", ast.functions.len()); + println!(" - Extracted {} types", ast.types.len()); + println!(" - Summary: {}", ast.structure_summary); +} + +#[tokio::test] +async fn test_ast_enhanced_reranking() { + // Create test results with different code snippets + let results = vec![ + create_test_result( + "fn multiply(x: i32, y: i32) -> i32 { x * y }", + "multiply.rs", + ), + create_test_result( + "async fn parse_json(data: &str) -> Result { serde_json::from_str(data) }", + "parser.rs", + ), + create_test_result("fn add(a: i32, b: i32) -> i32 { a + b }", "math.rs"), + ]; + + // Initialize AST-enhanced reranker + let reranker = CodeReranker::with_ast("BAAI/bge-small-en-v1.5".to_string()); + reranker.initialize().await.unwrap(); + + // Re-rank with query about parsing + let query = "parse json data"; + let reranked = reranker + .rerank_with_ast(query, results, 10, 0.7) + .await + .unwrap(); + + // Verify results + assert!(!reranked.is_empty()); + + // The parse_json function should rank highest for this query + let top_result = &reranked[0].0; + assert!( + top_result.file_path.contains("parser"), + "Parser file should rank highest for parse query" + ); + + // Verify AST metadata was added + assert!( + 
top_result.ast_metadata.is_some(), + "AST metadata should be present" + ); + + let ast = top_result.ast_metadata.as_ref().unwrap(); + if ast.parse_success { + println!("✅ AST-Enhanced Reranking Test Passed!"); + println!(" - Top result: {}", top_result.file_path); + println!(" - Functions found: {}", ast.functions.len()); + if !ast.functions.is_empty() { + println!(" - Top function: {}", ast.functions[0].name); + } + } +} + +fn create_test_result(code: &str, file_path: &str) -> CodeSearchResult { + CodeSearchResult { + platform: Platform::GitHub, + repository: "test/repo".to_string(), + file_path: file_path.to_string(), + language: Some("rust".to_string()), + file_url: "https://example.com".to_string(), + repository_url: "https://example.com".to_string(), + repository_stars: 100, + ast_metadata: None, + matches: vec![CodeMatch { + content: code.to_string(), + line_number: 1, + context_before: Vec::new(), + context_after: Vec::new(), + matched_functions: None, + matched_types: None, + }], + } +} diff --git a/crates/reposcout-tui/Cargo.toml b/crates/reposcout-tui/Cargo.toml index 3855d58..8596a68 100644 --- a/crates/reposcout-tui/Cargo.toml +++ b/crates/reposcout-tui/Cargo.toml @@ -13,6 +13,7 @@ reposcout-api = { path = "../reposcout-api" } reposcout-cache = { path = "../reposcout-cache" } reposcout-deps = { path = "../reposcout-deps" } reposcout-semantic = { path = "../reposcout-semantic" } +reposcout-ast = { path = "../reposcout-ast" } ratatui = { workspace = true } crossterm = { workspace = true } diff --git a/crates/reposcout-tui/src/app.rs b/crates/reposcout-tui/src/app.rs index 390805f..c6d38f5 100644 --- a/crates/reposcout-tui/src/app.rs +++ b/crates/reposcout-tui/src/app.rs @@ -100,12 +100,31 @@ impl SearchFilters { } } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct CodeSearchFilters { pub language: Option, pub repo: Option, pub path: Option, pub extension: Option, + pub semantic: bool, + pub semantic_weight: f32, + pub ast: bool, + pub 
ast_weight: f32, +} + +impl Default for CodeSearchFilters { + fn default() -> Self { + Self { + language: None, + repo: None, + path: None, + extension: None, + semantic: false, + semantic_weight: 0.7, + ast: false, + ast_weight: 0.3, + } + } } impl CodeSearchFilters { @@ -940,6 +959,9 @@ impl App { self.show_code_filters = !self.show_code_filters; if self.show_code_filters { self.code_filter_cursor = 0; + self.input_mode = InputMode::Filtering; + } else { + self.input_mode = InputMode::Normal; } } @@ -966,6 +988,63 @@ impl App { } } + /// Enter editing mode for code filter + pub fn enter_editing_code_filter_mode(&mut self) { + self.input_mode = InputMode::EditingFilter; + // Load current filter value into edit buffer + self.code_filter_edit_buffer = match self.code_filter_cursor { + 0 => self.code_filters.language.clone().unwrap_or_default(), + 1 => self.code_filters.repo.clone().unwrap_or_default(), + 2 => self.code_filters.path.clone().unwrap_or_default(), + 3 => self.code_filters.extension.clone().unwrap_or_default(), + _ => String::new(), + }; + } + + /// Save code filter edit + pub fn save_code_filter_edit(&mut self) { + // Save the edit buffer to the actual filter + match self.code_filter_cursor { + 0 => { + self.code_filters.language = if self.code_filter_edit_buffer.is_empty() { + None + } else { + Some(self.code_filter_edit_buffer.clone()) + }; + } + 1 => { + self.code_filters.repo = if self.code_filter_edit_buffer.is_empty() { + None + } else { + Some(self.code_filter_edit_buffer.clone()) + }; + } + 2 => { + self.code_filters.path = if self.code_filter_edit_buffer.is_empty() { + None + } else { + Some(self.code_filter_edit_buffer.clone()) + }; + } + 3 => { + self.code_filters.extension = if self.code_filter_edit_buffer.is_empty() { + None + } else { + Some(self.code_filter_edit_buffer.clone()) + }; + } + _ => {} + } + self.code_filter_edit_buffer.clear(); + self.input_mode = InputMode::Filtering; + } + + /// Cancel code filter edit + pub fn 
cancel_code_filter_edit(&mut self) { + self.code_filter_edit_buffer.clear(); + self.input_mode = InputMode::Filtering; + } + /// Toggle code preview mode (Code/Raw/FileInfo) pub fn toggle_code_preview_mode(&mut self) { self.code_preview_mode = match self.code_preview_mode { diff --git a/crates/reposcout-tui/src/code_ui.rs b/crates/reposcout-tui/src/code_ui.rs index 1dad566..0d0e86d 100644 --- a/crates/reposcout-tui/src/code_ui.rs +++ b/crates/reposcout-tui/src/code_ui.rs @@ -194,8 +194,9 @@ pub fn render_code_results_list(frame: &mut Frame, app: &App, area: Rect) { // Get preview of first match let preview = if let Some(first_match) = result.matches.first() { let content = first_match.content.trim(); - let truncated = if content.len() > 60 { - format!("{}...", &content[..60]) + let truncated = if content.chars().count() > 60 { + let chars: String = content.chars().take(60).collect(); + format!("{}...", chars) } else { content.to_string() }; @@ -261,9 +262,14 @@ pub fn render_code_results_list(frame: &mut Frame, app: &App, area: Rect) { Style::default() .bg(Color::Rgb(60, 60, 80)) .add_modifier(Modifier::BOLD), - ); + ) + .highlight_symbol("▸ "); + + // Create a ListState and sync it with the current selection + let mut list_state = ratatui::widgets::ListState::default(); + list_state.select(Some(app.code_selected_index)); - frame.render_widget(list, list_area); + frame.render_stateful_widget(list, list_area, &mut list_state); } /// Render code filter panel @@ -290,7 +296,7 @@ fn render_code_filter_panel(frame: &mut Frame, app: &App, area: Rect) { .add_modifier(Modifier::BOLD), ), Span::styled( - "(↑↓: navigate | Enter: edit | Del: clear | F: close)", + "(↑↓: navigate | Enter: edit | Del: clear | S: semantic | A: AST | F: close)", Style::default().fg(Color::DarkGray), ), ]), @@ -309,7 +315,13 @@ fn render_code_filter_panel(frame: &mut Frame, app: &App, area: Rect) { Style::default().fg(Color::Cyan) }; - let value_display = if value.is_empty() { "" } else { value }; 
+ let value_display = if is_editing { + &app.code_filter_edit_buffer + } else if value.is_empty() { + "" + } else { + value + }; let value_style = if is_editing { Style::default().fg(Color::Black).bg(Color::Yellow) } else if is_active { @@ -331,6 +343,67 @@ fn render_code_filter_panel(frame: &mut Frame, app: &App, area: Rect) { ])); } + // Add semantic mode indicator + lines.push(Line::from("")); + lines.push(Line::from(vec![ + Span::styled(" Mode: ", Style::default().fg(Color::Cyan)), + Span::styled( + if app.code_filters.semantic { + "🧠 Semantic (natural language)" + } else { + "🔍 Exact (text matching)" + }, + if app.code_filters.semantic { + Style::default() + .fg(Color::Magenta) + .add_modifier(Modifier::BOLD) + } else { + Style::default().fg(Color::White) + }, + ), + ])); + + if app.code_filters.semantic { + lines.push(Line::from(vec![ + Span::styled(" Weight: ", Style::default().fg(Color::Cyan)), + Span::styled( + format!("{:.1}%", app.code_filters.semantic_weight * 100.0), + Style::default().fg(Color::Green), + ), + Span::styled(" semantic", Style::default().fg(Color::DarkGray)), + ])); + } + + // Add AST mode indicator + lines.push(Line::from(vec![ + Span::styled(" AST: ", Style::default().fg(Color::Cyan)), + Span::styled( + if app.code_filters.ast { + "🌳 Enabled (structure-aware)" + } else { + "❌ Disabled" + }, + if app.code_filters.ast { + Style::default() + .fg(Color::Green) + .add_modifier(Modifier::BOLD) + } else { + Style::default().fg(Color::DarkGray) + }, + ), + ])); + + if app.code_filters.ast { + lines.push(Line::from(vec![ + Span::styled(" Weight: ", Style::default().fg(Color::Cyan)), + Span::styled( + format!("{:.1}%", app.code_filters.ast_weight * 100.0), + Style::default().fg(Color::Green), + ), + Span::styled(" AST", Style::default().fg(Color::DarkGray)), + ])); + } + let paragraph = Paragraph::new(lines).block( Block::default() .borders(Borders::ALL) diff --git a/crates/reposcout-tui/src/discovery_ui.rs 
b/crates/reposcout-tui/src/discovery_ui.rs index 0c4336d..75fc486 100644 --- a/crates/reposcout-tui/src/discovery_ui.rs +++ b/crates/reposcout-tui/src/discovery_ui.rs @@ -3,7 +3,7 @@ use ratatui::{ layout::{Alignment, Rect}, style::{Color, Modifier, Style}, text::{Line, Span}, - widgets::{Block, Borders, List, ListItem, Paragraph}, + widgets::{Block, Borders, List, ListItem, ListState, Paragraph}, Frame, }; @@ -228,34 +228,33 @@ fn render_topics(frame: &mut Frame, app: &App, area: Rect) { Line::from(""), ])]; - for (i, (topic, name)) in topics.iter().enumerate() { - let is_selected = i == app.discovery_cursor; - - let style = if is_selected { - Style::default() - .fg(Color::Yellow) - .add_modifier(Modifier::BOLD) - } else { - Style::default().fg(Color::White) - }; - - let indicator = if is_selected { "▶ " } else { " " }; - + for (topic, name) in topics.iter() { items.push(ListItem::new(vec![Line::from(vec![ - Span::styled(format!("{}{}", indicator, name), style), + Span::styled(format!(" {}", name), Style::default().fg(Color::White)), Span::raw(" "), Span::styled(format!("({})", topic), Style::default().fg(Color::DarkGray)), ])])); } - let list = List::new(items).block( - Block::default() - .borders(Borders::ALL) - .title("Topics") - .border_style(Style::default().fg(Color::Cyan)), - ); + let list = List::new(items) + .block( + Block::default() + .borders(Borders::ALL) + .title("Topics") + .border_style(Style::default().fg(Color::Cyan)), + ) + .highlight_style( + Style::default() + .fg(Color::Yellow) + .add_modifier(Modifier::BOLD), + ) + .highlight_symbol("▶ "); - frame.render_widget(list, area); + // Create list state with selection (offset by 1 for header item) + let mut list_state = ListState::default(); + list_state.select(Some(app.discovery_cursor + 1)); + + frame.render_stateful_widget(list, area, &mut list_state); } fn render_awesome_lists(frame: &mut Frame, app: &App, area: Rect) { @@ -276,35 +275,37 @@ fn render_awesome_lists(frame: &mut Frame, app: &App, 
area: Rect) { Line::from(""), ])]; - for (i, (repo, name)) in awesome_lists.iter().enumerate() { - let is_selected = i == app.discovery_cursor; - - let style = if is_selected { - Style::default() - .fg(Color::Yellow) - .add_modifier(Modifier::BOLD) - } else { - Style::default().fg(Color::White) - }; - - let indicator = if is_selected { "▶ " } else { " " }; - + for (repo, name) in awesome_lists.iter() { items.push(ListItem::new(vec![ - Line::from(vec![Span::styled(format!("{}{}", indicator, name), style)]), + Line::from(vec![Span::styled( + format!(" {}", name), + Style::default().fg(Color::White), + )]), Line::from(vec![ - Span::raw(" "), + Span::raw(" "), Span::styled(*repo, Style::default().fg(Color::DarkGray)), ]), Line::from(""), ])); } - let list = List::new(items).block( - Block::default() - .borders(Borders::ALL) - .title("Awesome Lists") - .border_style(Style::default().fg(Color::Cyan)), - ); + let list = List::new(items) + .block( + Block::default() + .borders(Borders::ALL) + .title("Awesome Lists") + .border_style(Style::default().fg(Color::Cyan)), + ) + .highlight_style( + Style::default() + .fg(Color::Yellow) + .add_modifier(Modifier::BOLD), + ) + .highlight_symbol("▶ "); - frame.render_widget(list, area); + // Create list state with selection (offset by 1 for header item) + let mut list_state = ListState::default(); + list_state.select(Some(app.discovery_cursor + 1)); + + frame.render_stateful_widget(list, area, &mut list_state); } diff --git a/crates/reposcout-tui/src/runner.rs b/crates/reposcout-tui/src/runner.rs index 13cae81..a1f9181 100644 --- a/crates/reposcout-tui/src/runner.rs +++ b/crates/reposcout-tui/src/runner.rs @@ -152,8 +152,9 @@ where { "Network error. Check your connection." 
.to_string() - } else if error_str.len() > 100 { - format!("{}...", &error_str[..100]) + } else if error_str.chars().count() > 100 { + let truncated: String = error_str.chars().take(100).collect(); + format!("{}...", truncated) } else { error_str }; @@ -168,14 +169,31 @@ where } SearchMode::Code => { // Perform code search - let query = app.get_code_search_query(); + let original_query = app.search_input.clone(); + let semantic = app.code_filters.semantic; + let semantic_weight = app.code_filters.semantic_weight; + let ast = app.code_filters.ast; + let ast_weight = app.code_filters.ast_weight; + + // Determine search strategy + let (search_query, api_limit) = if semantic { + use reposcout_semantic::extract_code_keywords; + let keywords = extract_code_keywords(&original_query); + tracing::info!("Original query: '{}', Extracted keywords: '{}'", original_query, keywords); + (app.code_filters.build_query(&keywords), 80) + } else { + (app.get_code_search_query(), 30) + }; + + tracing::info!("Code search query: '{}', API limit: {}", search_query, api_limit); // Search GitHub and GitLab for code let mut all_results = Vec::new(); // Search GitHub - match github_client.search_code(&query, 30).await { + match github_client.search_code(&search_query, api_limit).await { Ok(items) => { + tracing::info!("GitHub API returned {} code results", items.len()); for item in items { use reposcout_core::models::{ CodeMatch, CodeSearchResult, Platform, @@ -189,6 +207,8 @@ where line_number: 1, context_before: vec![], context_after: vec![], + matched_functions: None, + matched_types: None, }) .collect(); @@ -201,6 +221,8 @@ where line_number: 1, context_before: vec![], context_after: vec![], + matched_functions: None, + matched_types: None, }] } else { matches @@ -213,14 +235,16 @@ where .full_name .clone(), file_path: item.path.clone(), - language: None, // Code search API doesn't return language + language: item.repository.language.clone(), file_url: item.html_url.clone(), 
repository_url: item .repository .html_url - .clone(), + .clone() + .unwrap_or_else(|| format!("https://github.com/{}", item.repository.full_name)), matches, - repository_stars: 0, // Code search API doesn't return star count + repository_stars: item.repository.stargazers_count, + ast_metadata: None, }); } } @@ -245,8 +269,9 @@ where .to_string() } else { // Truncate long error messages - let short_msg = if error_str.len() > 100 { - format!("{}...", &error_str[..100]) + let short_msg = if error_str.chars().count() > 100 { + let truncated: String = error_str.chars().take(100).collect(); + format!("{}...", truncated) } else { error_str }; @@ -262,16 +287,64 @@ where } } - // Sort by stars - all_results.sort_by(|a, b| { - b.repository_stars.cmp(&a.repository_stars) - }); + tracing::info!("Total code results before reranking: {}", all_results.len()); + + // Apply semantic/AST re-ranking if enabled + let final_results = if (semantic || ast) && !all_results.is_empty() { + use reposcout_semantic::CodeReranker; + + if ast { + tracing::info!("Applying AST-enhanced re-ranking (semantic: {}, ast: {})", semantic_weight, ast_weight); + } else { + tracing::info!("Applying semantic re-ranking with weight: {}", semantic_weight); + } + + match (async { + let reranker = if ast { + CodeReranker::with_ast("BAAI/bge-small-en-v1.5".to_string()) + } else { + CodeReranker::new("BAAI/bge-small-en-v1.5".to_string()) + }; + reranker.initialize().await?; + + if ast { + use reposcout_ast::parse_query; + let parsed = parse_query(&original_query); + let filters = parsed.filters; + let results = reranker.rerank_with_ast_filters(&original_query, all_results.clone(), &filters, 30, semantic_weight, ast_weight).await?; + Ok::, anyhow::Error>(results.into_iter().map(|(result, _, _, _)| result).collect()) + } else { + let results = reranker.rerank_hybrid(&original_query, all_results.clone(), 30, semantic_weight).await?; + Ok::, anyhow::Error>(results.into_iter().map(|(result, _, _)| result).collect()) 
+ } + }).await { + Ok(results) => { + tracing::info!("Re-ranking returned {} results", results.len()); + results + } + Err(e) => { + tracing::warn!("Re-ranking failed: {}", e); + app.error_message = Some(format!("Search error: {}", e)); + // Fall back to star-based sorting + all_results.sort_by(|a, b| b.repository_stars.cmp(&a.repository_stars)); + all_results + } + } + } else { + // Sort by stars (traditional ranking) + all_results.sort_by(|a, b| { + b.repository_stars.cmp(&a.repository_stars) + }); + all_results + }; + + tracing::info!("Final code results count: {}", final_results.len()); - if all_results.is_empty() { + if final_results.is_empty() { app.error_message = Some("No code matches found. Try a different search query.".to_string()); } - app.set_code_results(all_results); + app.set_code_results(final_results); app.loading = false; } SearchMode::Semantic => { @@ -295,15 +368,8 @@ where Ok(engine) => { match engine.initialize().await { Ok(_) => { - // Convert to format expected by hybrid_search - let keyword_pairs: Vec<(reposcout_core::models::Repository, f32)> = keyword_results - .into_iter() - .enumerate() - .map(|(i, repo)| { - let score = 1.0 - (i as f32 / 100.0).min(0.9); - (repo, score) - }) - .collect(); + // Score keyword results using BM25 + let keyword_pairs = reposcout_semantic::score_keyword_results(keyword_results, &query); match engine .hybrid_search( @@ -382,40 +448,87 @@ where _ => {} }, InputMode::Filtering => match key.code { - KeyCode::Esc => { + KeyCode::Esc | KeyCode::Char('F') => { + // Handle both Esc and F to close filters + if app.show_code_filters { + app.show_code_filters = false; + } app.enter_normal_mode(); } KeyCode::Tab | KeyCode::Down | KeyCode::Char('j') => { - app.next_filter(); + if app.show_code_filters { + app.next_code_filter(); + } else { + app.next_filter(); + } } KeyCode::Up | KeyCode::Char('k') => { - app.previous_filter(); + if app.show_code_filters { + app.previous_code_filter(); + } else { + 
app.previous_filter(); + } } KeyCode::Delete | KeyCode::Char('d') => { - app.clear_current_filter(); + if app.show_code_filters { + app.clear_current_code_filter(); + } else { + app.clear_current_filter(); + } } KeyCode::Enter => { // Enter edit mode for this filter - app.enter_editing_filter_mode(); + if app.show_code_filters { + app.enter_editing_code_filter_mode(); + } else { + app.enter_editing_filter_mode(); + } + } + KeyCode::Char('s') => { + if app.show_code_filters { + // Toggle semantic mode in code filters + app.code_filters.semantic = !app.code_filters.semantic; + } else if app.filter_cursor == 4 { + // Cycle sort options with 's' key in repo filters + app.cycle_sort(); + } } - KeyCode::Char('s') if app.filter_cursor == 4 => { - // Cycle sort options with 's' key - app.cycle_sort(); + KeyCode::Char('a') => { + if app.show_code_filters { + // Toggle AST mode in code filters + app.code_filters.ast = !app.code_filters.ast; + } } _ => {} }, InputMode::EditingFilter => match key.code { KeyCode::Enter => { - app.save_filter_edit(); + if app.show_code_filters { + app.save_code_filter_edit(); + } else { + app.save_filter_edit(); + } } KeyCode::Esc => { - app.cancel_filter_edit(); + if app.show_code_filters { + app.cancel_code_filter_edit(); + } else { + app.cancel_filter_edit(); + } } KeyCode::Char(c) => { - app.filter_edit_buffer.push(c); + if app.show_code_filters { + app.code_filter_edit_buffer.push(c); + } else { + app.filter_edit_buffer.push(c); + } } KeyCode::Backspace => { - app.filter_edit_buffer.pop(); + if app.show_code_filters { + app.code_filter_edit_buffer.pop(); + } else { + app.filter_edit_buffer.pop(); + } } _ => {} }, @@ -523,14 +636,8 @@ where Ok(engine) => { match engine.initialize().await { Ok(_) => { - let keyword_pairs: Vec<(reposcout_core::models::Repository, f32)> = keyword_results - .into_iter() - .enumerate() - .map(|(i, repo)| { - let score = 1.0 - (i as f32 / 100.0).min(0.9); - (repo, score) - }) - .collect(); + // Score keyword 
results using BM25 + let keyword_pairs = reposcout_semantic::score_keyword_results(keyword_results, &query_str); match engine .hybrid_search( @@ -802,6 +909,34 @@ where // Force full redraw terminal.clear()?; } + KeyCode::Char('S') => { + // Toggle semantic search mode (only in code search mode) + if app.search_mode == SearchMode::Code { + app.code_filters.semantic = !app.code_filters.semantic; + // Show feedback message + let status = if app.code_filters.semantic { + "Semantic search ENABLED (natural language queries)" + } else { + "Semantic search DISABLED (exact text matching)" + }; + app.set_temp_error(format!("{} (Press Esc to dismiss)", status)); + terminal.clear()?; + } + } + KeyCode::Char('A') => { + // Toggle AST search mode (only in code search mode) + if app.search_mode == SearchMode::Code { + app.code_filters.ast = !app.code_filters.ast; + // Show feedback message + let status = if app.code_filters.ast { + "AST search ENABLED (structure-aware)" + } else { + "AST search DISABLED" + }; + app.set_temp_error(format!("{} (Press Esc to dismiss)", status)); + terminal.clear()?; + } + } KeyCode::Char('m') => { // Mark selected notification as read (only in notification mode) if app.search_mode == SearchMode::Notifications { @@ -1285,6 +1420,45 @@ where } } } + KeyCode::Char('F') => { + if app.search_mode == SearchMode::Code { + // Toggle code filters in Code mode + app.toggle_code_filters(); + } else if app.search_mode == SearchMode::Notifications { + // Toggle all/unread filter in notification mode + app.toggle_notification_filter(); + + // Refresh notifications with new filter + app.notifications_loading = true; + terminal.draw(|f| crate::ui::render(f, &mut app))?; + + match github_client + .get_notifications( + app.notifications_show_all, + app.notifications_participating, + 50, + ) + .await + { + Ok(notifications) => { + app.notifications = notifications; + app.notifications_selected_index = 0; + app.notifications_loading = false; + app.error_message = 
None; + } + Err(e) => { + app.error_message = Some(format!( + "Failed to fetch notifications: {}", + e + )); + app.notifications_loading = false; + } + } + } else { + // Toggle repo filters in other modes + app.toggle_filters(); + } + } KeyCode::Char('b') => { // Toggle bookmark for current repository if let Some(repo) = app.selected_repository() { @@ -1438,17 +1612,6 @@ where } } } - KeyCode::Char('F') => { - // Toggle filters based on search mode - if app.search_mode == SearchMode::Code { - app.toggle_code_filters(); - } else { - app.toggle_filters(); - if app.show_filters { - app.enter_filter_mode(); - } - } - } KeyCode::Tab => { // Tab cycles through preview tabs/modes based on search mode if app.search_mode == SearchMode::Discovery {