From 216a11b9db4e8ed21db93aec662aac871338c7a1 Mon Sep 17 00:00:00 2001 From: Alexander Date: Sun, 10 May 2026 10:59:43 +0200 Subject: [PATCH] feat(evil-keys): add keybinding crate with trie dispatch, count prefix, and timeout Plug'n'play modal keybinding system inspired by Doom Emacs + Evil mode. Generic over consumer Action type. Core: Key parser ("C-d", "SPC"), trie-based sequence matching with conflict detection, count prefix (5j), timeout tracking, which-key introspection, and multi-mode dispatch. 78 unit tests covering key parsing, trie conflicts, dispatch state machine, count accumulation, timeout expiry, and which-key generation. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/claude-agent) Co-authored-by: Sisyphus --- crates/evil-keys/Cargo.toml | 11 + crates/evil-keys/src/count.rs | 114 ++++++ crates/evil-keys/src/dispatch.rs | 615 ++++++++++++++++++++++++++++++ crates/evil-keys/src/error.rs | 78 ++++ crates/evil-keys/src/key.rs | 443 +++++++++++++++++++++ crates/evil-keys/src/lib.rs | 13 + crates/evil-keys/src/timeout.rs | 101 +++++ crates/evil-keys/src/trie.rs | 414 ++++++++++++++++++++ crates/evil-keys/src/which_key.rs | 111 ++++++ 9 files changed, 1900 insertions(+) create mode 100644 crates/evil-keys/Cargo.toml create mode 100644 crates/evil-keys/src/count.rs create mode 100644 crates/evil-keys/src/dispatch.rs create mode 100644 crates/evil-keys/src/error.rs create mode 100644 crates/evil-keys/src/key.rs create mode 100644 crates/evil-keys/src/lib.rs create mode 100644 crates/evil-keys/src/timeout.rs create mode 100644 crates/evil-keys/src/trie.rs create mode 100644 crates/evil-keys/src/which_key.rs diff --git a/crates/evil-keys/Cargo.toml b/crates/evil-keys/Cargo.toml new file mode 100644 index 0000000..8c470d4 --- /dev/null +++ b/crates/evil-keys/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "evil-keys" +version = "0.1.0" +edition = "2021" + +[dependencies] +crossterm = "0.28" +indexmap = "2" + +[dev-dependencies] +proptest = "1.4" diff --git a/crates/evil-keys/src/count.rs b/crates/evil-keys/src/count.rs new file mode 100644 index 0000000..375046b --- /dev/null +++ b/crates/evil-keys/src/count.rs @@ -0,0 +1,114 @@ +pub struct CountState { + digits: String, +} + +impl CountState { + pub fn new() -> Self { + Self { + digits: String::new(), + } + } + + pub fn push_digit(&mut self, d: char) { + self.digits.push(d); + } + + pub fn take(&mut self) -> usize { + if self.digits.is_empty() { + return 1; + } + let val = self.digits.parse::().unwrap_or(usize::MAX); + self.digits.clear(); + val + } + + pub fn is_active(&self) -> bool { + !self.digits.is_empty() + } + + pub fn display(&self) -> &str { + &self.digits + } + + pub fn reset(&mut self) { + self.digits.clear(); + } +} + +impl Default for CountState { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new_take() { + let mut count = CountState::new(); + assert_eq!(count.take(), 1); + } + + #[test] + fn test_push_take() { + let mut count = CountState::new(); + count.push_digit('5'); + assert_eq!(count.take(), 5); + assert_eq!(count.take(), 1); + } + + #[test] + fn test_multi_digit() { + let mut count = CountState::new(); + count.push_digit('1'); + count.push_digit('2'); + count.push_digit('3'); + assert_eq!(count.take(), 123); + } + + #[test] + fn test_leading_one() { + let mut count = CountState::new(); + count.push_digit('1'); + count.push_digit('0'); + assert_eq!(count.take(), 10); + } + + #[test] + fn test_saturate() { + let mut count = CountState::new(); + for _ in 0..20 { + count.push_digit('9'); + } + assert_eq!(count.take(), usize::MAX); + } + + #[test] + fn test_is_active() { + let mut count = CountState::new(); + assert!(!count.is_active()); + count.push_digit('5'); + assert!(count.is_active()); + count.take(); + assert!(!count.is_active()); + } + + #[test] + fn test_display() { + let mut count = CountState::new(); + assert_eq!(count.display(), ""); + count.push_digit('5'); + assert_eq!(count.display(), "5"); + } + + #[test] + fn test_reset() { + let mut count = CountState::new(); + count.push_digit('5'); + count.reset(); + assert!(!count.is_active()); + assert_eq!(count.display(), ""); + } +} diff --git a/crates/evil-keys/src/dispatch.rs b/crates/evil-keys/src/dispatch.rs new file mode 100644 index 0000000..fb2105c --- /dev/null +++ b/crates/evil-keys/src/dispatch.rs @@ -0,0 +1,615 @@ +use crossterm::event::{KeyCode, KeyEvent, KeyEventKind, KeyModifiers}; +use std::collections::HashMap; +use std::time::Duration; + +use crate::count::CountState; +use crate::error::ModeError; +use crate::key::Key; +use crate::timeout::TimeoutTracker; +use crate::trie::{KeyTrie, SearchResult}; +use crate::which_key::WhichKeyEntry; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DispatchResult { + Matched { action: A, count: usize }, + Pending, + Cancelled, + CountAccumulated, + Ignored, + NotFound, +} + +pub struct Dispatcher { + modes: HashMap>, + active_mode: String, + pending: Vec, + count: CountState, + timeout: TimeoutTracker, +} + +impl Dispatcher { + pub fn new() -> Self { + Self { + modes: HashMap::new(), + active_mode: String::new(), + pending: Vec::new(), + count: CountState::new(), + timeout: TimeoutTracker::new(Duration::from_secs(1)), + } + } + + pub fn add_mode(&mut self, name: &str, keymap: KeyTrie) -> Result<(), ModeError> { + self.modes.insert(name.to_string(), keymap); + Ok(()) + } + + pub fn set_active(&mut self, mode: &str) -> Result<(), ModeError> { + if !self.modes.contains_key(mode) { + return Err(ModeError::UnknownMode(mode.to_string())); + } + self.active_mode = mode.to_string(); + self.pending.clear(); + self.count.reset(); + self.timeout.reset(); + Ok(()) + } + + pub fn active_mode(&self) -> &str { + &self.active_mode + } + + pub fn set_timeout(&mut self, timeout: Duration) { + self.timeout = TimeoutTracker::new(timeout); + } + + pub fn dispatch(&mut self, event: KeyEvent) -> DispatchResult { + if event.kind != KeyEventKind::Press { + return DispatchResult::Ignored; + } + + let key = Key::from(event); + + let is_escape = key.code == KeyCode::Esc && key.modifiers == KeyModifiers::NONE; + if is_escape && (!self.pending.is_empty() || self.count.is_active()) { + self.pending.clear(); + self.count.reset(); + self.timeout.reset(); + return DispatchResult::Cancelled; + } + + if let KeyCode::Char(c) = key.code { + if key.modifiers == KeyModifiers::NONE && self.pending.is_empty() { + let is_count_digit = c.is_ascii_digit() && (c != '0' || self.count.is_active()); + if is_count_digit { + self.count.push_digit(c); + return DispatchResult::CountAccumulated; + } + } + } + + self.pending.push(key); + + let Some(trie) = self.modes.get(&self.active_mode) else { + self.pending.clear(); + self.count.reset(); + self.timeout.reset(); + return DispatchResult::NotFound; + }; + + match trie.search(&self.pending) { + SearchResult::Found(leaf) => { + let action = leaf.action.clone(); + let count = self.count.take(); + self.pending.clear(); + self.timeout.reset(); + DispatchResult::Matched { action, count } + } + SearchResult::Prefix(_) => { + if !self.timeout.is_active() { + self.timeout.start(); + } + DispatchResult::Pending + } + SearchResult::NotFound => { + self.pending.clear(); + self.count.reset(); + self.timeout.reset(); + DispatchResult::NotFound + } + } + } + + pub fn check_timeout(&mut self) -> bool { + if self.timeout.check() { + self.pending.clear(); + self.count.reset(); + return true; + } + false + } + + pub fn pending_keys(&self) -> &[Key] { + &self.pending + } + + pub fn pending_display(&self) -> String { + self.pending + .iter() + .map(|k| k.to_string()) + .collect::>() + .join(" ") + } + + pub fn pending_elapsed(&self) -> Duration { + self.timeout.elapsed() + } + + pub fn clear_pending(&mut self) { + self.pending.clear(); + self.count.reset(); + self.timeout.reset(); + } + + pub fn which_key_entries(&self) -> Option> { + if self.pending.is_empty() { + return None; + } + + let trie = self.modes.get(&self.active_mode)?; + + if let SearchResult::Prefix(node) = trie.search(&self.pending) { + Some(node.which_key_entries()) + } else { + None + } + } + + pub fn count_display(&self) -> Option<&str> { + if self.count.is_active() { + Some(self.count.display()) + } else { + None + } + } +} + +impl Default for Dispatcher { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::trie::KeyTrie; + use crossterm::event::KeyEventState; + + fn press(code: KeyCode) -> KeyEvent { + KeyEvent { + code, + modifiers: KeyModifiers::NONE, + kind: KeyEventKind::Press, + state: KeyEventState::NONE, + } + } + + fn release(code: KeyCode) -> KeyEvent { + KeyEvent { + code, + modifiers: KeyModifiers::NONE, + kind: KeyEventKind::Release, + state: KeyEventState::NONE, + } + } + + #[test] + fn test_dispatch_non_press() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("j", "down").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(release(KeyCode::Char('j'))), + DispatchResult::Ignored + ); + } + + #[test] + fn test_dispatch_no_modes() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + assert_eq!( + disp.dispatch(press(KeyCode::Char('j'))), + DispatchResult::NotFound + ); + } + + #[test] + fn test_simple_match() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("j", "down").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('j'))), + DispatchResult::Matched { + action: "down", + count: 1 + } + ); + } + + #[test] + fn test_sequence_match() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("g g", "goto_top").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Matched { + action: "goto_top", + count: 1 + } + ); + } + + #[test] + fn test_wrong_key_mid_sequence() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("g g", "goto_top").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('h'))), + DispatchResult::NotFound + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + } + + #[test] + fn test_escape_nothing_pending_bound() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("Esc", "escape_action").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Esc)), + DispatchResult::Matched { + action: "escape_action", + count: 1 + } + ); + } + + #[test] + fn test_escape_clears_pending() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("g g", "goto_top").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + assert_eq!( + disp.dispatch(press(KeyCode::Esc)), + DispatchResult::Cancelled + ); + } + + #[test] + fn test_escape_clears_count() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("j", "down").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('5'))), + DispatchResult::CountAccumulated + ); + assert_eq!( + disp.dispatch(press(KeyCode::Esc)), + DispatchResult::Cancelled + ); + } + + #[test] + fn test_escape_clears_both() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("g g", "goto_top").unwrap(); + trie.bind("j", "down").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('3'))), + DispatchResult::CountAccumulated + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + assert_eq!( + disp.dispatch(press(KeyCode::Esc)), + DispatchResult::Cancelled + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('j'))), + DispatchResult::Matched { + action: "down", + count: 1 + } + ); + } + + #[test] + fn test_count_zero_as_binding() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("0", "start_of_line").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('0'))), + DispatchResult::Matched { + action: "start_of_line", + count: 1 + } + ); + } + + #[test] + fn test_count_10() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("j", "down").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('1'))), + DispatchResult::CountAccumulated + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('0'))), + DispatchResult::CountAccumulated + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('j'))), + DispatchResult::Matched { + action: "down", + count: 10 + } + ); + } + + #[test] + fn test_count_through_prefix() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("g g", "goto_top").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('5'))), + DispatchResult::CountAccumulated + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Matched { + action: "goto_top", + count: 5 + } + ); + } + + #[test] + fn test_digits_during_pending() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("g 3", "some_action").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('3'))), + DispatchResult::Matched { + action: "some_action", + count: 1 + } + ); + } + + #[test] + fn test_mode_switch_clears() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut normal = KeyTrie::new("normal"); + normal.bind("g g", "goto_top").unwrap(); + let insert = KeyTrie::new("insert"); + + disp.add_mode("normal", normal).unwrap(); + disp.add_mode("insert", insert).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + disp.set_active("insert").unwrap(); + assert!(disp.pending_keys().is_empty()); + } + + #[test] + fn test_which_key_entries_after_spc() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.group("SPC", "leader", |node| { + node.bind_desc("b", "buffers", "Buffers")?; + node.bind_desc("f", "files", "Files")?; + Ok(()) + }) + .unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char(' '))), + DispatchResult::Pending + ); + let entries = disp.which_key_entries().unwrap(); + assert!(!entries.is_empty()); + } + + #[test] + fn test_which_key_entries_nothing_pending() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let trie = KeyTrie::new("normal"); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert!(disp.which_key_entries().is_none()); + } + + #[test] + fn test_pending_display() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("SPC b l", "list_buffers").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + disp.dispatch(press(KeyCode::Char(' '))); + assert_eq!(disp.pending_display(), "SPC"); + + disp.dispatch(press(KeyCode::Char('b'))); + assert_eq!(disp.pending_display(), "SPC b"); + } + + #[test] + fn test_count_display() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("j", "down").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert!(disp.count_display().is_none()); + disp.dispatch(press(KeyCode::Char('5'))); + assert_eq!(disp.count_display(), Some("5")); + } + + #[test] + fn test_timeout_clears_count() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("j", "down").unwrap(); + trie.bind("g g", "top").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + disp.set_timeout(std::time::Duration::from_millis(50)); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('5'))), + DispatchResult::CountAccumulated + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + std::thread::sleep(std::time::Duration::from_millis(80)); + assert!(disp.check_timeout()); + assert_eq!( + disp.dispatch(press(KeyCode::Char('j'))), + DispatchResult::Matched { + action: "down", + count: 1 + } + ); + } + + #[test] + fn test_escape_full_progression() { + let mut disp: Dispatcher<&str> = Dispatcher::new(); + let mut trie = KeyTrie::new("normal"); + trie.bind("j", "down").unwrap(); + trie.bind("g g", "top").unwrap(); + disp.add_mode("normal", trie).unwrap(); + disp.set_active("normal").unwrap(); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('5'))), + DispatchResult::CountAccumulated + ); + assert_eq!( + disp.dispatch(press(KeyCode::Esc)), + DispatchResult::Cancelled + ); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + assert_eq!( + disp.dispatch(press(KeyCode::Esc)), + DispatchResult::Cancelled + ); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('3'))), + DispatchResult::CountAccumulated + ); + assert_eq!( + disp.dispatch(press(KeyCode::Char('g'))), + DispatchResult::Pending + ); + assert_eq!( + disp.dispatch(press(KeyCode::Esc)), + DispatchResult::Cancelled + ); + + assert_eq!( + disp.dispatch(press(KeyCode::Char('j'))), + DispatchResult::Matched { + action: "down", + count: 1 + } + ); + } +} diff --git a/crates/evil-keys/src/error.rs b/crates/evil-keys/src/error.rs new file mode 100644 index 0000000..40c6dfb --- /dev/null +++ b/crates/evil-keys/src/error.rs @@ -0,0 +1,78 @@ +use std::fmt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ParseError { + EmptyInput, + UnknownKey(String), + DanglingModifier, + DuplicateModifier, + RedundantShift, +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::EmptyInput => write!(f, "empty input"), + Self::UnknownKey(k) => write!(f, "unknown key: {k}"), + Self::DanglingModifier => write!(f, "dangling modifier (e.g. \"C-\")"), + Self::DuplicateModifier => write!(f, "duplicate modifier"), + Self::RedundantShift => write!( + f, + "redundant shift on uppercase char (use \"G\" not \"S-G\")" + ), + } + } +} + +impl std::error::Error for ParseError {} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BindError { + EmptySequence, + ConflictWithLeaf { existing_keys: String }, + ConflictWithPrefix { existing_keys: String }, + Parse(ParseError), +} + +impl fmt::Display for BindError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::EmptySequence => write!(f, "empty key sequence"), + Self::ConflictWithLeaf { existing_keys } => { + write!( + f, + "conflicts with existing leaf binding at \"{existing_keys}\"" + ) + } + Self::ConflictWithPrefix { existing_keys } => { + write!(f, "conflicts with existing prefix at \"{existing_keys}\"") + } + Self::Parse(e) => write!(f, "parse error: {e}"), + } + } +} + +impl std::error::Error for BindError {} + +impl From for BindError { + fn from(e: ParseError) -> Self { + Self::Parse(e) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ModeError { + UnknownMode(String), + DuplicateMode(String), +} + +impl fmt::Display for ModeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::UnknownMode(m) => write!(f, "unknown mode: \"{m}\""), + Self::DuplicateMode(m) => write!(f, "mode already exists: \"{m}\""), + } + } +} + +impl std::error::Error for ModeError {} diff --git a/crates/evil-keys/src/key.rs b/crates/evil-keys/src/key.rs new file mode 100644 index 0000000..07c1c7b --- /dev/null +++ b/crates/evil-keys/src/key.rs @@ -0,0 +1,443 @@ +use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; +use std::fmt; +use std::hash::{Hash, Hasher}; + +use crate::error::ParseError; + +#[derive(Clone, Copy, Debug)] +pub struct Key { + pub code: KeyCode, + pub modifiers: KeyModifiers, +} + +impl PartialEq for Key { + fn eq(&self, other: &Self) -> bool { + self.code == other.code && self.modifiers == other.modifiers + } +} + +impl Eq for Key {} + +impl Hash for Key { + fn hash(&self, state: &mut H) { + hash_key_code(&self.code, state); + self.modifiers.bits().hash(state); + } +} + +fn hash_key_code(code: &KeyCode, state: &mut H) { + match code { + KeyCode::Backspace => (0u8).hash(state), + KeyCode::Enter => (1u8).hash(state), + KeyCode::Left => (2u8).hash(state), + KeyCode::Right => (3u8).hash(state), + KeyCode::Up => (4u8).hash(state), + KeyCode::Down => (5u8).hash(state), + KeyCode::Home => (6u8).hash(state), + KeyCode::End => (7u8).hash(state), + KeyCode::PageUp => (8u8).hash(state), + KeyCode::PageDown => (9u8).hash(state), + KeyCode::Tab => (10u8).hash(state), + KeyCode::BackTab => (11u8).hash(state), + KeyCode::Delete => (12u8).hash(state), + KeyCode::Insert => (13u8).hash(state), + KeyCode::F(n) => { + (14u8).hash(state); + n.hash(state); + } + KeyCode::Char(c) => { + (15u8).hash(state); + c.hash(state); + } + KeyCode::Null => (16u8).hash(state), + KeyCode::Esc => (17u8).hash(state), + KeyCode::CapsLock => (18u8).hash(state), + KeyCode::ScrollLock => (19u8).hash(state), + KeyCode::NumLock => (20u8).hash(state), + KeyCode::PrintScreen => (21u8).hash(state), + KeyCode::Pause => (22u8).hash(state), + KeyCode::Menu => (23u8).hash(state), + KeyCode::KeypadBegin => (24u8).hash(state), + KeyCode::Media(m) => { + (25u8).hash(state); + (*m as u8).hash(state); + } + KeyCode::Modifier(m) => { + (26u8).hash(state); + (*m as u8).hash(state); + } + } +} + +impl From for Key { + fn from(e: KeyEvent) -> Self { + Key { + code: e.code, + modifiers: e.modifiers, + } + } +} + +impl fmt::Display for Key { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.modifiers.contains(KeyModifiers::CONTROL) { + write!(f, "C-")?; + } + if self.modifiers.contains(KeyModifiers::SHIFT) { + write!(f, "S-")?; + } + if self.modifiers.contains(KeyModifiers::ALT) { + write!(f, "A-")?; + } + + match self.code { + KeyCode::Char(' ') => write!(f, "SPC"), + KeyCode::Char(c) => write!(f, "{c}"), + KeyCode::Esc => write!(f, "Esc"), + KeyCode::Tab => write!(f, "Tab"), + KeyCode::Enter => write!(f, "Enter"), + KeyCode::Backspace => write!(f, "Backspace"), + KeyCode::F(n) => write!(f, "F{n}"), + KeyCode::Left => write!(f, "Left"), + KeyCode::Right => write!(f, "Right"), + KeyCode::Up => write!(f, "Up"), + KeyCode::Down => write!(f, "Down"), + KeyCode::Home => write!(f, "Home"), + KeyCode::End => write!(f, "End"), + KeyCode::PageUp => write!(f, "PageUp"), + KeyCode::PageDown => write!(f, "PageDown"), + KeyCode::Delete => write!(f, "Delete"), + KeyCode::Insert => write!(f, "Insert"), + _ => write!(f, "?"), + } + } +} + +pub fn parse_key(input: &str) -> Result { + if input.trim().is_empty() { + return Err(ParseError::EmptyInput); + } + + if input.contains(' ') { + return Err(ParseError::UnknownKey(input.to_string())); + } + + let trimmed = input; + + let mut modifiers = KeyModifiers::NONE; + let mut remaining = trimmed; + let mut has_ctrl = false; + let mut has_shift = false; + let mut has_alt = false; + + loop { + if remaining.starts_with("C-") { + if has_ctrl { + return Err(ParseError::DuplicateModifier); + } + has_ctrl = true; + modifiers |= KeyModifiers::CONTROL; + remaining = &remaining[2..]; + } else if remaining.starts_with("S-") { + if has_shift { + return Err(ParseError::DuplicateModifier); + } + has_shift = true; + modifiers |= KeyModifiers::SHIFT; + remaining = &remaining[2..]; + } else if remaining.starts_with("A-") { + if has_alt { + return Err(ParseError::DuplicateModifier); + } + has_alt = true; + modifiers |= KeyModifiers::ALT; + remaining = &remaining[2..]; + } else { + break; + } + } + + if remaining.is_empty() { + return Err(ParseError::DanglingModifier); + } + + let code = match remaining { + "SPC" => KeyCode::Char(' '), + "Esc" => KeyCode::Esc, + "Tab" => KeyCode::Tab, + "Enter" => KeyCode::Enter, + "Backspace" => KeyCode::Backspace, + "Left" => KeyCode::Left, + "Right" => KeyCode::Right, + "Up" => KeyCode::Up, + "Down" => KeyCode::Down, + "Home" => KeyCode::Home, + "End" => KeyCode::End, + "PageUp" => KeyCode::PageUp, + "PageDown" => KeyCode::PageDown, + "Delete" => KeyCode::Delete, + "Insert" => KeyCode::Insert, + "-" => KeyCode::Char('-'), + s if s.starts_with('F') && s.len() > 1 => { + let num_str = &s[1..]; + match num_str.parse::() { + Ok(n) if (1..=12).contains(&n) => KeyCode::F(n), + _ => return Err(ParseError::UnknownKey(input.to_string())), + } + } + s if s.len() == 1 => { + let c = s.chars().next().expect("non-empty string"); + if !c.is_ascii() { + return Err(ParseError::UnknownKey(input.to_string())); + } + if has_shift && c.is_ascii_uppercase() { + return Err(ParseError::RedundantShift); + } + KeyCode::Char(c) + } + _ => return Err(ParseError::UnknownKey(input.to_string())), + }; + + Ok(Key { code, modifiers }) +} + +pub fn parse_sequence(input: &str) -> Result, ParseError> { + let trimmed = input.trim(); + + if trimmed.is_empty() { + return Err(ParseError::EmptyInput); + } + + trimmed.split_whitespace().map(parse_key).collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_input() { + assert_eq!(parse_key(""), Err(ParseError::EmptyInput)); + assert_eq!(parse_key(" "), Err(ParseError::EmptyInput)); + } + + #[test] + fn test_space_in_key() { + assert!(matches!(parse_key(" j"), Err(ParseError::UnknownKey(_)))); + } + + #[test] + fn test_dangling_modifier() { + assert_eq!(parse_key("C-"), Err(ParseError::DanglingModifier)); + } + + #[test] + fn test_duplicate_modifier() { + assert_eq!(parse_key("C-C-d"), Err(ParseError::DuplicateModifier)); + } + + #[test] + fn test_lowercase_modifier() { + assert!(matches!(parse_key("c-d"), Err(ParseError::UnknownKey(_)))); + } + + #[test] + fn test_ctrl_d() { + let key = parse_key("C-d").unwrap(); + assert_eq!(key.code, KeyCode::Char('d')); + assert_eq!(key.modifiers, KeyModifiers::CONTROL); + } + + #[test] + fn test_ctrl_shift_d() { + let key = parse_key("C-S-d").unwrap(); + assert_eq!(key.code, KeyCode::Char('d')); + assert_eq!(key.modifiers, KeyModifiers::CONTROL | KeyModifiers::SHIFT); + } + + #[test] + fn test_modifier_order_normalized() { + let key1 = parse_key("S-C-d").unwrap(); + let key2 = parse_key("C-S-d").unwrap(); + assert_eq!(key1, key2); + } + + #[test] + fn test_uppercase_no_shift() { + let key = parse_key("G").unwrap(); + assert_eq!(key.code, KeyCode::Char('G')); + assert_eq!(key.modifiers, KeyModifiers::NONE); + } + + #[test] + fn test_shift_lowercase() { + let key = parse_key("S-g").unwrap(); + assert_eq!(key.code, KeyCode::Char('g')); + assert_eq!(key.modifiers, KeyModifiers::SHIFT); + } + + #[test] + fn test_redundant_shift() { + assert_eq!(parse_key("S-G"), Err(ParseError::RedundantShift)); + } + + #[test] + fn test_bare_modifier_letters() { + let key = parse_key("C").unwrap(); + assert_eq!(key.code, KeyCode::Char('C')); + assert_eq!(key.modifiers, KeyModifiers::NONE); + + let key = parse_key("S").unwrap(); + assert_eq!(key.code, KeyCode::Char('S')); + assert_eq!(key.modifiers, KeyModifiers::NONE); + } + + #[test] + fn test_special_keys() { + let key = parse_key("SPC").unwrap(); + assert_eq!(key.code, KeyCode::Char(' ')); + assert_eq!(key.modifiers, KeyModifiers::NONE); + + assert!(matches!(parse_key("spc"), Err(ParseError::UnknownKey(_)))); + + let key = parse_key("Esc").unwrap(); + assert_eq!(key.code, KeyCode::Esc); + assert_eq!(key.modifiers, KeyModifiers::NONE); + + let key = parse_key("Tab").unwrap(); + assert_eq!(key.code, KeyCode::Tab); + assert_eq!(key.modifiers, KeyModifiers::NONE); + + let key = parse_key("Enter").unwrap(); + assert_eq!(key.code, KeyCode::Enter); + assert_eq!(key.modifiers, KeyModifiers::NONE); + + // Tab != 't' + assert_ne!(parse_key("Tab").unwrap(), parse_key("t").unwrap()); + } + + #[test] + fn test_function_keys() { + let key = parse_key("F1").unwrap(); + assert_eq!(key.code, KeyCode::F(1)); + assert_eq!(key.modifiers, KeyModifiers::NONE); + + let key = parse_key("F12").unwrap(); + assert_eq!(key.code, KeyCode::F(12)); + assert_eq!(key.modifiers, KeyModifiers::NONE); + + assert!(matches!(parse_key("F0"), Err(ParseError::UnknownKey(_)))); + assert!(matches!(parse_key("F13"), Err(ParseError::UnknownKey(_)))); + + let key = parse_key("F").unwrap(); + assert_eq!(key.code, KeyCode::Char('F')); + assert_eq!(key.modifiers, KeyModifiers::NONE); + + let key = parse_key("C-F1").unwrap(); + assert_eq!(key.code, KeyCode::F(1)); + assert_eq!(key.modifiers, KeyModifiers::CONTROL); + } + + #[test] + fn test_symbols() { + let key = parse_key("-").unwrap(); + assert_eq!(key.code, KeyCode::Char('-')); + assert_eq!(key.modifiers, KeyModifiers::NONE); + + let key = parse_key("C--").unwrap(); + assert_eq!(key.code, KeyCode::Char('-')); + assert_eq!(key.modifiers, KeyModifiers::CONTROL); + + let key = parse_key("[").unwrap(); + assert_eq!(key.code, KeyCode::Char('[')); + assert_eq!(key.modifiers, KeyModifiers::NONE); + + let key = parse_key("?").unwrap(); + assert_eq!(key.code, KeyCode::Char('?')); + assert_eq!(key.modifiers, KeyModifiers::NONE); + } + + #[test] + fn test_digits() { + for c in '0'..='9' { + let key = parse_key(&c.to_string()).unwrap(); + assert_eq!(key.code, KeyCode::Char(c)); + assert_eq!(key.modifiers, KeyModifiers::NONE); + } + } + + #[test] + fn test_invalid_non_ascii() { + assert!(matches!(parse_key("é"), Err(ParseError::UnknownKey(_)))); + } + + #[test] + fn test_space_is_sequence() { + assert!(matches!(parse_key("g g"), Err(ParseError::UnknownKey(_)))); + } + + #[test] + fn test_display_roundtrip() { + let test_cases = [ + "j", "G", "C-d", "C-S-d", "SPC", "Esc", "Tab", "Enter", "F1", "-", "C--", "[", + ]; + for input in test_cases { + let key = parse_key(input).unwrap(); + let displayed = key.to_string(); + let reparsed = parse_key(&displayed).unwrap(); + assert_eq!(key, reparsed, "roundtrip failed for {input}"); + } + } + + #[test] + fn test_display_specific() { + let key = Key { + code: KeyCode::Char(' '), + modifiers: KeyModifiers::NONE, + }; + assert_eq!(key.to_string(), "SPC"); + + let key = Key { + code: KeyCode::Char('d'), + modifiers: KeyModifiers::CONTROL | KeyModifiers::SHIFT, + }; + assert_eq!(key.to_string(), "C-S-d"); + } + + #[test] + fn test_sequence_parsing() { + let seq = parse_sequence("g g").unwrap(); + assert_eq!(seq.len(), 2); + assert_eq!(seq[0].code, KeyCode::Char('g')); + assert_eq!(seq[1].code, KeyCode::Char('g')); + + let seq = parse_sequence("SPC b l").unwrap(); + assert_eq!(seq.len(), 3); + assert_eq!(seq[0].code, KeyCode::Char(' ')); + assert_eq!(seq[1].code, KeyCode::Char('b')); + assert_eq!(seq[2].code, KeyCode::Char('l')); + + assert_eq!(parse_sequence(""), Err(ParseError::EmptyInput)); + assert_eq!(parse_sequence(" "), Err(ParseError::EmptyInput)); + } + + #[test] + fn test_from_key_event() { + let event = KeyEvent::new(KeyCode::Char('j'), KeyModifiers::NONE); + let key = Key::from(event); + assert_eq!(key.code, KeyCode::Char('j')); + assert_eq!(key.modifiers, KeyModifiers::NONE); + } + + #[test] + fn test_key_hash() { + use std::collections::HashMap; + + let mut map = HashMap::new(); + let key1 = parse_key("C-d").unwrap(); + let key2 = parse_key("C-d").unwrap(); + + map.insert(key1, "action"); + assert_eq!(map.get(&key2), Some(&"action")); + } +} diff --git a/crates/evil-keys/src/lib.rs b/crates/evil-keys/src/lib.rs new file mode 100644 index 0000000..12fba2e --- /dev/null +++ b/crates/evil-keys/src/lib.rs @@ -0,0 +1,13 @@ +pub mod count; +pub mod dispatch; +pub mod error; +pub mod key; +pub mod timeout; +pub mod trie; +pub mod which_key; + +pub use dispatch::{DispatchResult, Dispatcher}; +pub use error::{BindError, ModeError, ParseError}; +pub use key::Key; +pub use trie::{KeyTrie, KeyTrieNode, LeafBinding, SearchResult}; +pub use which_key::WhichKeyEntry; diff --git a/crates/evil-keys/src/timeout.rs b/crates/evil-keys/src/timeout.rs new file mode 100644 index 0000000..5d79fb0 --- /dev/null +++ b/crates/evil-keys/src/timeout.rs @@ -0,0 +1,101 @@ +use std::time::{Duration, Instant}; + +pub struct TimeoutTracker { + timeout: Duration, + started_at: Option, +} + +impl TimeoutTracker { + pub fn new(timeout: Duration) -> Self { + Self { + timeout, + started_at: None, + } + } + + pub fn start(&mut self) { + self.started_at = Some(Instant::now()); + } + + pub fn check(&mut self) -> bool { + if let Some(started) = self.started_at { + if started.elapsed() >= self.timeout { + self.started_at = None; + return true; + } + } + false + } + + pub fn elapsed(&self) -> Duration { + self.started_at + .map(|s| s.elapsed()) + .unwrap_or(Duration::ZERO) + } + + pub fn reset(&mut self) { + self.started_at = None; + } + + pub fn is_active(&self) -> bool { + self.started_at.is_some() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + + #[test] + fn test_check_nothing_started() { + let mut tracker = TimeoutTracker::new(Duration::from_millis(100)); + assert!(!tracker.check()); + } + + #[test] + fn test_check_expired() { + let mut tracker = TimeoutTracker::new(Duration::from_millis(100)); + tracker.start(); + thread::sleep(Duration::from_millis(150)); + assert!(tracker.check()); + } + + #[test] + fn test_check_not_expired() { + let mut tracker = TimeoutTracker::new(Duration::from_millis(100)); + tracker.start(); + assert!(!tracker.check()); + } + + #[test] + fn test_reset() { + let mut tracker = TimeoutTracker::new(Duration::from_millis(100)); + tracker.start(); + tracker.reset(); + assert!(!tracker.check()); + } + + #[test] + fn test_elapsed_nothing_started() { + let tracker = TimeoutTracker::new(Duration::from_millis(100)); + assert_eq!(tracker.elapsed(), Duration::ZERO); + } + + #[test] + fn test_zero_timeout() { + let mut tracker = TimeoutTracker::new(Duration::ZERO); + tracker.start(); + assert!(tracker.check()); + } + + #[test] + fn test_is_active() { + let mut tracker = TimeoutTracker::new(Duration::from_millis(100)); + assert!(!tracker.is_active()); + tracker.start(); + assert!(tracker.is_active()); + tracker.reset(); + assert!(!tracker.is_active()); + } +} diff --git a/crates/evil-keys/src/trie.rs b/crates/evil-keys/src/trie.rs new file mode 100644 index 0000000..e0e7e1f --- /dev/null +++ b/crates/evil-keys/src/trie.rs @@ -0,0 +1,414 @@ +use crate::error::BindError; +use crate::key::{parse_sequence, Key}; +use indexmap::IndexMap; + +pub enum KeyTrie { + Leaf(LeafBinding), + Node(KeyTrieNode), +} + +pub struct LeafBinding { + pub action: A, + pub description: Option, +} + +pub struct KeyTrieNode { + pub name: String, + pub map: IndexMap>, +} + +pub enum SearchResult<'a, A> { + Found(&'a LeafBinding), + Prefix(&'a KeyTrieNode), + NotFound, +} + +impl KeyTrie { + pub fn new(name: &str) -> Self { + KeyTrie::Node(KeyTrieNode { + name: name.to_string(), + map: IndexMap::new(), + }) + } + + pub fn bind(&mut self, keys: &str, action: A) -> Result<(), BindError> { + self.bind_internal(keys, action, None) + } + + pub fn bind_desc(&mut self, keys: &str, action: A, desc: &str) -> Result<(), BindError> { + self.bind_internal(keys, action, Some(desc.to_string())) + } + + fn bind_internal( + &mut self, + keys: &str, + action: A, + description: Option, + ) -> Result<(), BindError> { + let sequence = parse_sequence(keys)?; + + if sequence.is_empty() { + return Err(BindError::EmptySequence); + } + + let node = match self { + KeyTrie::Node(n) => n, + KeyTrie::Leaf(_) => { + return Err(BindError::ConflictWithLeaf { + existing_keys: String::new(), + }) + } + }; + + node.bind_sequence(&sequence, action, description, String::new()) + } + + pub fn group(&mut self, key: &str, name: &str, f: F) -> Result<(), BindError> + where + F: FnOnce(&mut KeyTrieNode) -> Result<(), BindError>, + { + let parsed_key = crate::key::parse_key(key)?; + + let node = match self { + KeyTrie::Node(n) => n, + KeyTrie::Leaf(_) => { + return Err(BindError::ConflictWithLeaf { + existing_keys: String::new(), + }) + } + }; + + let entry = node.map.entry(parsed_key); + let child_node = match entry { + indexmap::map::Entry::Occupied(o) => match o.into_mut() { + KeyTrie::Node(n) => n, + KeyTrie::Leaf(_) => { + return Err(BindError::ConflictWithLeaf { + existing_keys: key.to_string(), + }) + } + }, + indexmap::map::Entry::Vacant(v) => { + let new_node = KeyTrie::Node(KeyTrieNode { + name: name.to_string(), + map: IndexMap::new(), + }); + match v.insert(new_node) { + KeyTrie::Node(n) => n, + KeyTrie::Leaf(_) => unreachable!(), + } + } + }; + + f(child_node) + } + + pub fn search(&self, keys: &[Key]) -> SearchResult<'_, A> { + if keys.is_empty() { + return SearchResult::NotFound; + } + + let node = match self { + KeyTrie::Node(n) => n, + KeyTrie::Leaf(l) => return SearchResult::Found(l), + }; + + node.search(keys) + } +} + +impl KeyTrieNode { + pub fn bind(&mut self, keys: &str, action: A) -> Result<(), BindError> { + let sequence = parse_sequence(keys)?; + + if sequence.is_empty() { + return Err(BindError::EmptySequence); + } + + self.bind_sequence(&sequence, action, None, String::new()) + } + + pub fn bind_desc(&mut self, keys: &str, action: A, desc: &str) -> Result<(), BindError> { + let sequence = parse_sequence(keys)?; + + if sequence.is_empty() { + return Err(BindError::EmptySequence); + } + + self.bind_sequence(&sequence, action, Some(desc.to_string()), String::new()) + } + + fn bind_sequence( + &mut self, + keys: &[Key], + action: A, + description: Option, + path: String, + ) -> Result<(), BindError> { + let (first, rest) = keys.split_first().expect("non-empty keys"); + let current_path = if path.is_empty() { + first.to_string() + } else { + format!("{} {}", path, first) + }; + + if rest.is_empty() { + match self.map.get(first) { + Some(KeyTrie::Node(_)) => { + return Err(BindError::ConflictWithPrefix { + existing_keys: current_path, + }); + } + Some(KeyTrie::Leaf(_)) | None => { + self.map.insert( + *first, + KeyTrie::Leaf(LeafBinding { + action, + description, + }), + ); + return Ok(()); + } + } + } + + let entry = self.map.entry(*first); + match entry { + indexmap::map::Entry::Occupied(mut o) => match o.get_mut() { + KeyTrie::Leaf(_) => { + return Err(BindError::ConflictWithLeaf { + existing_keys: current_path, + }); + } + KeyTrie::Node(n) => { + n.bind_sequence(rest, action, description, current_path)?; + } + }, + indexmap::map::Entry::Vacant(v) => { + let mut new_node = KeyTrieNode { + name: String::new(), + map: IndexMap::new(), + }; + new_node.bind_sequence(rest, action, description, current_path)?; + v.insert(KeyTrie::Node(new_node)); + } + } + + Ok(()) + } + + pub fn search(&self, keys: &[Key]) -> SearchResult<'_, A> { + if keys.is_empty() { + return SearchResult::NotFound; + } + + let (first, rest) = keys.split_first().expect("non-empty keys"); + + match self.map.get(first) { + None => SearchResult::NotFound, + Some(KeyTrie::Leaf(l)) if rest.is_empty() => SearchResult::Found(l), + Some(KeyTrie::Leaf(_)) => SearchResult::NotFound, + Some(KeyTrie::Node(n)) if rest.is_empty() => SearchResult::Prefix(n), + Some(KeyTrie::Node(n)) => n.search(rest), + } + } + + pub fn group(&mut self, key: &str, name: &str, f: F) -> Result<(), BindError> + where + F: FnOnce(&mut KeyTrieNode) -> Result<(), BindError>, + { + let parsed_key = crate::key::parse_key(key)?; + + let entry = self.map.entry(parsed_key); + let child_node = match entry { + indexmap::map::Entry::Occupied(o) => match o.into_mut() { + KeyTrie::Node(n) => n, + KeyTrie::Leaf(_) => { + return Err(BindError::ConflictWithLeaf { + existing_keys: key.to_string(), + }) + } + }, + indexmap::map::Entry::Vacant(v) => { + let new_node = KeyTrie::Node(KeyTrieNode { + name: name.to_string(), + map: IndexMap::new(), + }); + match v.insert(new_node) { + KeyTrie::Node(n) => n, + KeyTrie::Leaf(_) => unreachable!(), + } + } + }; + + f(child_node) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::key::parse_key; + + #[test] + fn test_conflict_leaf_then_prefix() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind("g", "action").unwrap(); + + let result = trie.bind("g g", "action2"); + assert!(matches!(result, Err(BindError::ConflictWithLeaf { .. }))); + } + + #[test] + fn test_conflict_prefix_then_leaf() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind("g g", "action").unwrap(); + + let result = trie.bind("g", "action2"); + assert!(matches!(result, Err(BindError::ConflictWithPrefix { .. }))); + } + + #[test] + fn test_siblings_ok() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind("g g", "action1").unwrap(); + trie.bind("g h", "action2").unwrap(); + } + + #[test] + fn test_overwrite() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind("g", "first").unwrap(); + trie.bind("g", "second").unwrap(); + + let key = parse_key("g").unwrap(); + if let SearchResult::Found(leaf) = trie.search(&[key]) { + assert_eq!(leaf.action, "second"); + } else { + panic!("expected Found"); + } + } + + #[test] + fn test_empty_bind() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + let result = trie.bind("", "action"); + assert!(matches!(result, Err(BindError::Parse(_)))); + } + + #[test] + fn test_search_empty_keys() { + let trie: KeyTrie<&str> = KeyTrie::new("root"); + assert!(matches!(trie.search(&[]), SearchResult::NotFound)); + } + + #[test] + fn test_search_empty_trie() { + let trie: KeyTrie<&str> = KeyTrie::new("root"); + let key = parse_key("g").unwrap(); + assert!(matches!(trie.search(&[key]), SearchResult::NotFound)); + } + + #[test] + fn test_search_prefix() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind("g g", "action").unwrap(); + + let g = parse_key("g").unwrap(); + assert!(matches!(trie.search(&[g]), SearchResult::Prefix(_))); + } + + #[test] + fn test_search_found() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind("g g", "action").unwrap(); + + let g = parse_key("g").unwrap(); + if let SearchResult::Found(leaf) = trie.search(&[g, g]) { + assert_eq!(leaf.action, "action"); + } else { + panic!("expected Found"); + } + } + + #[test] + fn test_search_not_found_wrong_key() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind("g g", "action").unwrap(); + + let g = parse_key("g").unwrap(); + let h = parse_key("h").unwrap(); + assert!(matches!(trie.search(&[g, h]), SearchResult::NotFound)); + } + + #[test] + fn test_search_beyond_leaf() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind("g", "action").unwrap(); + + let g = parse_key("g").unwrap(); + let h = parse_key("h").unwrap(); + assert!(matches!(trie.search(&[g, h]), SearchResult::NotFound)); + } + + #[test] + fn test_group_then_bind_at_group_key() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.group("g", "goto", |_| Ok(())).unwrap(); + + let result = trie.bind("g", "action"); + assert!(matches!(result, Err(BindError::ConflictWithPrefix { .. }))); + } + + #[test] + fn test_empty_group() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.group("g", "goto", |_| Ok(())).unwrap(); + + let g = parse_key("g").unwrap(); + if let SearchResult::Prefix(node) = trie.search(&[g]) { + assert!(node.map.is_empty()); + } else { + panic!("expected Prefix"); + } + } + + #[test] + fn test_deep_nesting() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind("a b c d e f g h", "deep").unwrap(); + + let keys: Vec = "a b c d e f g h" + .split_whitespace() + .map(|s| parse_key(s).unwrap()) + .collect(); + + assert!(matches!(trie.search(&keys), SearchResult::Found(_))); + } + + #[test] + fn test_wide_single_level() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + for c in 'a'..='z' { + trie.bind(&c.to_string(), "action").unwrap(); + } + + for c in 'a'..='z' { + let key = parse_key(&c.to_string()).unwrap(); + assert!(matches!(trie.search(&[key]), SearchResult::Found(_))); + } + } + + #[test] + fn test_bind_with_description() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind_desc("j", "down", "Move down").unwrap(); + + let j = parse_key("j").unwrap(); + if let SearchResult::Found(leaf) = trie.search(&[j]) { + assert_eq!(leaf.description, Some("Move down".to_string())); + } else { + panic!("expected Found"); + } + } +} diff --git a/crates/evil-keys/src/which_key.rs b/crates/evil-keys/src/which_key.rs new file mode 100644 index 0000000..ba65604 --- /dev/null +++ b/crates/evil-keys/src/which_key.rs @@ -0,0 +1,111 @@ +use crate::trie::{KeyTrie, KeyTrieNode}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WhichKeyEntry { + pub key: String, + pub description: String, + pub is_group: bool, +} + +impl KeyTrieNode { + pub fn which_key_entries(&self) -> Vec { + let mut groups = Vec::new(); + let mut leaves = Vec::new(); + + for (key, trie) in &self.map { + match trie { + KeyTrie::Node(node) => { + groups.push(WhichKeyEntry { + key: key.to_string(), + description: node.name.clone(), + is_group: true, + }); + } + KeyTrie::Leaf(leaf) => { + leaves.push(WhichKeyEntry { + key: key.to_string(), + description: leaf.description.clone().unwrap_or_default(), + is_group: false, + }); + } + } + } + + groups.sort_by(|a, b| a.key.cmp(&b.key)); + leaves.sort_by(|a, b| a.key.cmp(&b.key)); + groups.extend(leaves); + groups + } +} + +#[cfg(test)] +mod tests { + use crate::trie::KeyTrie; + + #[test] + fn test_empty_node() { + let trie: KeyTrie<&str> = KeyTrie::new("root"); + if let KeyTrie::Node(node) = trie { + let entries = node.which_key_entries(); + assert!(entries.is_empty()); + } + } + + #[test] + fn test_leaves_only_sorted() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind_desc("c", "action_c", "C action").unwrap(); + trie.bind_desc("a", "action_a", "A action").unwrap(); + trie.bind_desc("b", "action_b", "B action").unwrap(); + + if let KeyTrie::Node(node) = trie { + let entries = node.which_key_entries(); + assert_eq!(entries.len(), 3); + assert_eq!(entries[0].key, "a"); + assert_eq!(entries[1].key, "b"); + assert_eq!(entries[2].key, "c"); + } + } + + #[test] + fn test_groups_before_leaves() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind_desc("z", "action", "Z action").unwrap(); + trie.group("g", "goto", |node| { + node.bind("g", "goto_top")?; + Ok(()) + }) + .unwrap(); + trie.bind_desc("a", "action", "A action").unwrap(); + + if let KeyTrie::Node(node) = trie { + let entries = node.which_key_entries(); + assert_eq!(entries.len(), 3); + assert!(entries[0].is_group); + assert_eq!(entries[0].key, "g"); + assert!(!entries[1].is_group); + assert_eq!(entries[1].key, "a"); + assert!(!entries[2].is_group); + assert_eq!(entries[2].key, "z"); + } + } + + #[test] + fn test_descriptions_populated() { + let mut trie: KeyTrie<&str> = KeyTrie::new("root"); + trie.bind_desc("j", "down", "Move down").unwrap(); + trie.group("g", "goto", |_| Ok(())).unwrap(); + + if let KeyTrie::Node(node) = trie { + let entries = node.which_key_entries(); + + let group = entries.iter().find(|e| e.key == "g").unwrap(); + assert_eq!(group.description, "goto"); + assert!(group.is_group); + + let leaf = entries.iter().find(|e| e.key == "j").unwrap(); + assert_eq!(leaf.description, "Move down"); + assert!(!leaf.is_group); + } + } +}