From 5a76bc6ba1c6f65b43fd26d9417c97c6e3ba5aba Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Tue, 19 May 2026 11:07:23 +0900 Subject: [PATCH] Optimize global graph key hashing --- src/fragment.rs | 8 +-- src/materialize.rs | 17 ++--- src/types.rs | 161 ++++++++++++++++++++++++++++++++++++++++-- tests/scalar_tests.rs | 51 ++++++++++--- 4 files changed, 207 insertions(+), 30 deletions(-) diff --git a/src/fragment.rs b/src/fragment.rs index 35de3de..8c7e62d 100644 --- a/src/fragment.rs +++ b/src/fragment.rs @@ -116,17 +116,13 @@ impl FragmentBuilder { .map(|input| self.resolve_input_key(input)) .collect(); - let global_op_key = GlobalOpKey { - primitive: op.clone(), - inputs: global_inputs, - mode: mode.clone(), - }; + let global_op_key = Arc::new(GlobalOpKey::new(op.clone(), global_inputs, mode.clone())); let mut output_ids = Vec::with_capacity(n_outputs); for slot in 0..n_outputs { let val_id = self.vals.len(); let key = GlobalValKey::Derived { - op: global_op_key.clone(), + op: Arc::clone(&global_op_key), output_slot: slot as u8, }; self.vals.push(ValNode { diff --git a/src/materialize.rs b/src/materialize.rs index e8776bf..ec85e78 100644 --- a/src/materialize.rs +++ b/src/materialize.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use crate::resolve::{ResolvedView, ValDef}; use crate::traits::GraphOp; @@ -30,7 +31,7 @@ pub struct MaterializedGraph { struct Materializer<'a, Op: GraphOp> { view: &'a ResolvedView, val_map: HashMap, usize>, - op_map: HashMap, usize>, + op_map: HashMap>, usize>, vals: Vec>, ops: Vec>, input_keys: Vec>, @@ -89,11 +90,11 @@ impl<'a, Op: GraphOp> Materializer<'a, Op> { mode: OpMode, output_slot: usize, ) -> usize { - let op_key = GlobalOpKey { - primitive: op.clone(), - inputs: input_keys.clone(), - mode: mode.clone(), - }; + let op_key = Arc::new(GlobalOpKey::new( + op.clone(), + input_keys.clone(), + mode.clone(), + )); if self.op_map.contains_key(&op_key) { let output_key = GlobalValKey::Derived { @@ -115,7 +116,7 @@ impl<'a, Op: GraphOp> Materializer<'a, Op> { let materialized_inputs = input_keys.iter().map(|input| self.visit(input)).collect(); let op_index = self.ops.len(); - self.op_map.insert(op_key.clone(), op_index); + self.op_map.insert(Arc::clone(&op_key), op_index); self.ops.push(MaterializedOp { op: op.clone(), inputs: materialized_inputs, @@ -125,7 +126,7 @@ impl<'a, Op: GraphOp> Materializer<'a, Op> { for slot in 0..op.n_outputs() { let output_key = GlobalValKey::Derived { - op: op_key.clone(), + op: Arc::clone(&op_key), output_slot: slot as u8, }; let val_index = self.vals.len(); diff --git a/src/types.rs b/src/types.rs index 6e4950f..2b7462c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,3 +1,7 @@ +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + use crate::traits::GraphOp; /// Fragment-local value identifier. @@ -21,19 +25,164 @@ pub enum ValRef { } /// Cross-fragment structural identity for a value. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug)] pub enum GlobalValKey { Input(Op::InputKey), Derived { - op: GlobalOpKey, + /// Shared structural identity of the operation that produced this value. + op: Arc>, output_slot: u8, }, } /// Cross-fragment structural identity for an operation. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +/// +/// `GlobalOpKey` caches a structural fingerprint so maps keyed by recursively +/// derived values can avoid repeatedly re-hashing the whole input tree. Equality +/// still checks the full structure after the fingerprint prefilter. +#[derive(Clone, Debug)] pub struct GlobalOpKey { - pub primitive: Op, - pub inputs: Vec>, - pub mode: OpMode, + primitive: Op, + inputs: Vec>, + mode: OpMode, + /// Cached hash prefilter for recursively structural keys. + /// + /// This is not an identity proof: equality still compares the full + /// structure after the fingerprint matches, so hash collisions remain + /// correct. + fingerprint: u64, +} + +impl GlobalOpKey { + /// Builds an operation key and precomputes its structural fingerprint. + /// + /// # Examples + /// + /// ```ignore + /// use computegraph::{GlobalOpKey, GlobalValKey, GraphOp, OpMode}; + /// + /// #[derive(Clone, Debug, Hash, PartialEq, Eq)] + /// enum Op { + /// Add, + /// } + /// + /// impl GraphOp for Op { + /// type Operand = f64; + /// type Context = (); + /// type InputKey = &'static str; + /// + /// fn n_inputs(&self) -> usize { 2 } + /// fn n_outputs(&self) -> usize { 1 } + /// } + /// + /// let key = GlobalOpKey::new( + /// Op::Add, + /// vec![GlobalValKey::Input("x"), GlobalValKey::Input("y")], + /// OpMode::Primal, + /// ); + /// assert_eq!(key.inputs().len(), 2); + /// ``` + pub fn new(primitive: Op, inputs: Vec>, mode: OpMode) -> Self { + let fingerprint = fingerprint_op(&primitive, &inputs, &mode); + Self { + primitive, + inputs, + mode, + fingerprint, + } + } + + /// Returns the cached structural fingerprint. + pub fn fingerprint(&self) -> u64 { + self.fingerprint + } + + /// Returns the operation primitive. + pub fn primitive(&self) -> &Op { + &self.primitive + } + + /// Returns the structural input keys. + pub fn inputs(&self) -> &[GlobalValKey] { + &self.inputs + } + + /// Returns whether this operation belongs to the primal or linear graph. + pub fn mode(&self) -> &OpMode { + &self.mode + } +} + +impl PartialEq for GlobalValKey { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Input(lhs), Self::Input(rhs)) => lhs == rhs, + ( + Self::Derived { + op: lhs_op, + output_slot: lhs_slot, + }, + Self::Derived { + op: rhs_op, + output_slot: rhs_slot, + }, + ) => { + lhs_slot == rhs_slot + && (Arc::ptr_eq(lhs_op, rhs_op) || lhs_op.as_ref() == rhs_op.as_ref()) + } + _ => false, + } + } +} + +impl Eq for GlobalValKey {} + +impl Hash for GlobalValKey { + fn hash(&self, state: &mut H) { + match self { + Self::Input(key) => { + 0u8.hash(state); + key.hash(state); + } + Self::Derived { op, output_slot } => { + 1u8.hash(state); + op.fingerprint.hash(state); + output_slot.hash(state); + } + } + } +} + +impl PartialEq for GlobalOpKey { + fn eq(&self, other: &Self) -> bool { + self.fingerprint == other.fingerprint + && self.primitive == other.primitive + && self.mode == other.mode + && self.inputs == other.inputs + } +} + +impl Eq for GlobalOpKey {} + +impl Hash for GlobalOpKey { + fn hash(&self, state: &mut H) { + self.fingerprint.hash(state); + } +} + +fn fingerprint_op(primitive: &Op, inputs: &[GlobalValKey], mode: &OpMode) -> u64 { + let mut hasher = DefaultHasher::new(); + primitive.hash(&mut hasher); + mode.hash(&mut hasher); + inputs.len().hash(&mut hasher); + for input in inputs { + fingerprint_val(input).hash(&mut hasher); + } + hasher.finish() +} + +fn fingerprint_val(key: &GlobalValKey) -> u64 { + let mut hasher = DefaultHasher::new(); + key.hash(&mut hasher); + hasher.finish() } diff --git a/tests/scalar_tests.rs b/tests/scalar_tests.rs index 9a26912..7ec295b 100644 --- a/tests/scalar_tests.rs +++ b/tests/scalar_tests.rs @@ -1,5 +1,7 @@ mod common; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use common::ScalarOp; @@ -74,14 +76,14 @@ fn interner_get_returns_none_for_unknown() { fn interner_derived_key() { let mut interner = KeyInterner::::new(); let key = GlobalValKey::::Derived { - op: GlobalOpKey { - primitive: ScalarOp::Add, - inputs: vec![ + op: Arc::new(GlobalOpKey::new( + ScalarOp::Add, + vec![ GlobalValKey::Input("x".to_string()), GlobalValKey::Input("y".to_string()), ], - mode: OpMode::Primal, - }, + OpMode::Primal, + )), output_slot: 0, }; let id = interner.intern(key.clone()); @@ -89,6 +91,35 @@ fn interner_derived_key() { assert_eq!(interner.get(&key), Some(id)); } +#[test] +fn derived_keys_with_distinct_op_arcs_are_structurally_equal() { + let inputs = vec![ + GlobalValKey::Input("x".to_string()), + GlobalValKey::Input("y".to_string()), + ]; + let lhs = GlobalValKey::::Derived { + op: Arc::new(GlobalOpKey::new( + ScalarOp::Add, + inputs.clone(), + OpMode::Primal, + )), + output_slot: 0, + }; + let rhs = GlobalValKey::::Derived { + op: Arc::new(GlobalOpKey::new(ScalarOp::Add, inputs, OpMode::Primal)), + output_slot: 0, + }; + + assert_eq!(lhs, rhs); + assert_eq!(hash_key(&lhs), hash_key(&rhs)); +} + +fn hash_key(key: &GlobalValKey) -> u64 { + let mut hasher = DefaultHasher::new(); + key.hash(&mut hasher); + hasher.finish() +} + // === Fragment tests === #[test] @@ -125,14 +156,14 @@ fn fragment_builder_add_op() { // Verify GlobalValKey structure let expected_key = GlobalValKey::Derived { - op: GlobalOpKey { - primitive: ScalarOp::Add, - inputs: vec![ + op: Arc::new(GlobalOpKey::new( + ScalarOp::Add, + vec![ GlobalValKey::Input("x".to_string()), GlobalValKey::Input("y".to_string()), ], - mode: OpMode::Primal, - }, + OpMode::Primal, + )), output_slot: 0, }; assert_eq!(frag.vals()[sum_id].key, expected_key);