Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ serde_json = "1"
toml = "0"
env_logger = "0"
reqwest = { version = "0", default-features = false, features = ["http2", "json", "blocking", "multipart", "rustls-tls"] }
regex = "1.11.1"

[dev-dependencies]
tempfile = "3"
Expand Down
182 changes: 173 additions & 9 deletions src/config/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::config::{api::Api, resolve_config_path};

const PROMPT_FILE: &str = "prompts.toml";
const CONVERSATION_FILE: &str = "conversation.toml";
const CONVERSATIONS_PATH: &str = "saved_conversations";

#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Prompt {
Expand Down Expand Up @@ -100,15 +101,62 @@ pub fn conversation_file_path() -> PathBuf {
resolve_config_path().join(CONVERSATION_FILE)
}

pub fn get_last_conversation_as_prompt() -> Prompt {
let content = fs::read_to_string(conversation_file_path()).unwrap_or_else(|error| {
panic!(
"Could not read file {:?}, {:?}",
conversation_file_path(),
error
)
});
toml::from_str(&content).expect("failed to load the conversation file")
// Get the path to the conversations directory
pub fn conversations_path() -> PathBuf {
resolve_config_path().join(CONVERSATIONS_PATH)
}

// Get the path to a specific conversation file
pub fn named_conversation_path(name: &str) -> PathBuf {
conversations_path().join(format!("{}.toml", name))
}

// Get the last conversation as a prompt, if it exists
pub fn get_last_conversation_as_prompt(name: Option<&str>) -> Option<Prompt> {
if let Some(name) = name {
let named_path = named_conversation_path(name);
if !named_path.exists() {
return None;
}
let content = fs::read_to_string(named_path)
.unwrap_or_else(|error| {
panic!(
"Could not read file {:?}, {:?}",
named_conversation_path(name),
error
)
});
Some(toml::from_str(&content).expect("failed to load the conversation file"))
} else {
let path = conversation_file_path();
if !path.exists() {
return None;
}
let content = fs::read_to_string(path)
.unwrap_or_else(|error| {
panic!(
"Could not read file {:?}, {:?}",
conversation_file_path(),
error
)
});
Some(toml::from_str(&content).expect("failed to load the conversation file"))
}
}

pub fn save_conversation(prompt: &Prompt, name: Option<&str>) -> std::io::Result<()> {
let toml_string = toml::to_string(prompt).expect("Failed to serialize prompt");

// Always save to conversation.toml
fs::write(conversation_file_path(), &toml_string)?;

// If name is provided, also save to named conversation file
if let Some(name) = name {
fs::create_dir_all(conversations_path())?;
fs::write(named_conversation_path(name), &toml_string)?;
}

Ok(())
}

pub(super) fn generate_prompts_file() -> std::io::Result<()> {
Expand Down Expand Up @@ -136,3 +184,119 @@ pub fn get_prompts() -> HashMap<String, Prompt> {
.unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", prompts_path(), error));
toml::from_str(&content).expect("could not parse prompt file content")
}

#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::tempdir;
use crate::config::prompt::Prompt;
use serial_test::serial;

fn setup() -> tempfile::TempDir {
let temp_dir = tempdir().unwrap();
std::env::set_var("SMARTCAT_CONFIG_PATH", temp_dir.path());
temp_dir
}

fn create_test_prompt() -> Prompt {
let mut prompt = Prompt::default();
prompt.messages = vec![(Message::user("test"))];
prompt
}

#[test]
#[serial]
fn test_get_and_save_default_conversation() {
let _temp_dir = setup();
let test_prompt = create_test_prompt();

// Test saving conversation
save_conversation(&test_prompt, None).unwrap();
assert!(conversation_file_path().exists());

// Test retrieving conversation
let loaded_prompt = get_last_conversation_as_prompt(None).unwrap();
assert_eq!(loaded_prompt, test_prompt);
}

#[test]
#[serial]
fn test_get_and_save_named_conversation() {
let _temp_dir = setup();
let test_prompt = create_test_prompt();
let conv_name = "test_conversation";

// Test saving named conversation
save_conversation(&test_prompt, Some(conv_name)).unwrap();
assert!(named_conversation_path(conv_name).exists());
assert!(conversation_file_path().exists()); // Should also save to default location

// Test retrieving named conversation
let loaded_prompt = get_last_conversation_as_prompt(Some(conv_name)).unwrap();
assert_eq!(loaded_prompt, test_prompt);
}

#[test]
#[serial]
fn test_nonexistent_conversation() {
let _temp_dir = setup();

// Test getting nonexistent default conversation
assert!(get_last_conversation_as_prompt(None).is_none());

// Test getting nonexistent named conversation
assert!(get_last_conversation_as_prompt(Some("nonexistent")).is_none());
}

#[test]
#[serial]
fn test_conversation_file_contents() {
let _temp_dir = setup();
let test_prompt = create_test_prompt();
let conv_name = "test_conversation";

// Save conversation
save_conversation(&test_prompt, Some(conv_name)).unwrap();

// Verify default and named files have identical content
let default_content = fs::read_to_string(conversation_file_path()).unwrap();
let named_content = fs::read_to_string(named_conversation_path(conv_name)).unwrap();
assert_eq!(default_content, named_content);

// Verify content can be parsed back to original prompt
let parsed_prompt: Prompt = toml::from_str(&default_content).unwrap();
assert_eq!(parsed_prompt, test_prompt);
}

#[test]
#[serial]
fn test_generate_prompts_file() {
let _temp_dir = setup();

// Test file generation
generate_prompts_file().unwrap();
assert!(prompts_path().exists());

// Verify file is valid TOML and contains expected content
let content = fs::read_to_string(prompts_path()).unwrap();
let prompts: HashMap<String, Prompt> = toml::from_str(&content).unwrap();
assert!(!prompts.is_empty());
}

#[test]
#[serial]
fn test_get_prompts() {
let _temp_dir = setup();

// Generate prompts file
generate_prompts_file().unwrap();

// Test loading prompts
let prompts = get_prompts();
assert!(!prompts.is_empty());

// Verify at least one default prompt exists
assert!(prompts.contains_key("default"));
}
}
Loading