From 7c90bc9afae1157bdba5cce43ad14193a20d9ffc Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Mon, 1 Jun 2026 19:23:25 +0900 Subject: [PATCH] rename graph API terminology --- README.md | 8 +- src/compile.rs | 24 +-- src/eval.rs | 8 +- src/fragment.rs | 196 ------------------ src/graph.rs | 241 ++++++++++++++++++++++ src/interner.rs | 28 +-- src/lib.rs | 6 +- src/materialize.rs | 116 +++++------ src/resolve.rs | 92 ++++----- src/traits.rs | 82 ++------ src/types.rs | 124 +++++------ tests/common/mod.rs | 10 +- tests/scalar_tests.rs | 470 +++++++++++++++++++++++------------------- 13 files changed, 731 insertions(+), 674 deletions(-) delete mode 100644 src/fragment.rs create mode 100644 src/graph.rs diff --git a/README.md b/README.md index 207ec3a..b8f7154 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,17 @@ # computegraph-rs -AD-agnostic tensor computation graph engine in Rust. +Operation-agnostic computation graph engine in Rust. -Provides fragment-based graph construction, logical resolution, +Provides graph construction, logical resolution, physical materialization, SSA compilation, and evaluation. -Fully generic over `Op: GraphOp` — never references specific primitives. +Fully generic over `Operation: GraphOperation`; it never references specific +primitive operation sets. ## Part of the tensor4all v2 stack ```text computegraph-rs ← this crate -chainrules-rs ← AD trait definitions tidu-rs ← AD graph transforms tenferro-rs ← concrete tensor primitives ``` diff --git a/src/compile.rs b/src/compile.rs index 0d62c43..e576dae 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -1,18 +1,18 @@ use std::collections::HashMap; use crate::materialize::MaterializedGraph; -use crate::traits::GraphOp; -use crate::types::GlobalValKey; +use crate::traits::GraphOperation; +use crate::types::ValueKey; /// A single instruction in the compiled program. -pub struct Instruction { - pub op: Op, +pub struct Instruction { + pub operation: Op, pub inputs: Vec, pub outputs: Vec, } /// SSA-form compiled program. Each slot is written exactly once. -pub struct CompiledProgram { +pub struct CompiledProgram { pub instructions: Vec>, pub input_slots: Vec, pub output_slots: Vec, @@ -20,27 +20,27 @@ pub struct CompiledProgram { } /// Compiles a materialized graph into an SSA instruction sequence. -pub fn compile(graph: &MaterializedGraph) -> CompiledProgram { +pub fn compile(graph: &MaterializedGraph) -> CompiledProgram { let instructions = graph - .ops + .operations .iter() .map(|op_node| Instruction { - op: op_node.op.clone(), + operation: op_node.operation.clone(), inputs: op_node.inputs.clone(), outputs: op_node.outputs.clone(), }) .collect(); let input_slots = graph - .vals + .values .iter() .enumerate() .filter(|(_, val)| val.producer.is_none()) .map(|(index, _)| index) .collect(); - let key_to_index: HashMap<&GlobalValKey, usize> = graph - .vals + let key_to_index: HashMap<&ValueKey, usize> = graph + .values .iter() .enumerate() .map(|(index, val)| (&val.key, index)) @@ -67,6 +67,6 @@ pub fn compile(graph: &MaterializedGraph) -> CompiledProgram CompiledProgram { +impl CompiledProgram { /// Executes the compiled program with the given inputs. pub fn eval(&self, ctx: &mut Op::Context, inputs: &[&Op::Operand]) -> Vec { assert_eq!( @@ -31,12 +31,12 @@ impl CompiledProgram { }) .collect(); - let outputs = instruction.op.eval(ctx, &input_vals); + let outputs = instruction.operation.eval(ctx, &input_vals); assert_eq!( outputs.len(), instruction.outputs.len(), "operation {:?} produced {} outputs, expected {}", - instruction.op, + instruction.operation, outputs.len(), instruction.outputs.len() ); diff --git a/src/fragment.rs b/src/fragment.rs deleted file mode 100644 index 8c7e62d..0000000 --- a/src/fragment.rs +++ /dev/null @@ -1,196 +0,0 @@ -use std::sync::Arc; - -use crate::traits::GraphOp; -use crate::types::{GlobalOpKey, GlobalValKey, LocalOpId, LocalValId, OpMode, ValRef}; - -/// A value node in a fragment. -pub struct ValNode { - /// Cross-fragment structural identity. - pub key: GlobalValKey, - /// `None` for fragment inputs; `Some((op_id, output_slot))` for produced values. - pub producer: Option<(LocalOpId, usize)>, -} - -/// An operation node in a fragment. -pub struct OpNode { - pub op: Op, - pub inputs: Vec>, - pub outputs: Vec, - pub mode: OpMode, -} - -/// The unit of graph construction. -pub struct Fragment { - pub(crate) vals: Vec>, - pub(crate) ops: Vec>, - pub(crate) inputs: Vec, - pub(crate) outputs: Vec, - pub(crate) parents: Vec>>, -} - -impl Fragment { - /// Returns all value nodes in this fragment. - pub fn vals(&self) -> &[ValNode] { - &self.vals - } - - /// Returns all operation nodes in this fragment. - pub fn ops(&self) -> &[OpNode] { - &self.ops - } - - /// Returns the fragment-local input value ids. - pub fn inputs(&self) -> &[LocalValId] { - &self.inputs - } - - /// Returns the fragment-local output value ids. - pub fn outputs(&self) -> &[LocalValId] { - &self.outputs - } - - /// Returns the parent fragments referenced by this fragment. - pub fn parents(&self) -> &[Arc>] { - &self.parents - } -} - -/// Builder for constructing fragments incrementally. -pub struct FragmentBuilder { - vals: Vec>, - ops: Vec>, - inputs: Vec, - outputs: Vec, - parents: Vec>>, - local_keys: Vec>, -} - -impl FragmentBuilder { - /// Creates an empty fragment builder. - pub fn new() -> Self { - Self { - vals: Vec::new(), - ops: Vec::new(), - inputs: Vec::new(), - outputs: Vec::new(), - parents: Vec::new(), - local_keys: Vec::new(), - } - } - - /// Adds a fragment input and returns its local id. - pub fn add_input(&mut self, key: Op::InputKey) -> LocalValId { - let val_id = self.vals.len(); - let global_key = GlobalValKey::Input(key); - self.vals.push(ValNode { - key: global_key.clone(), - producer: None, - }); - self.local_keys.push(global_key); - self.inputs.push(val_id); - val_id - } - - /// Adds an operation node and returns the local ids for each output. - pub fn add_op(&mut self, op: Op, inputs: Vec>, mode: OpMode) -> Vec { - assert_eq!( - inputs.len(), - op.n_inputs(), - "operation {:?} expected {} inputs, got {}", - op, - op.n_inputs(), - inputs.len() - ); - - let n_outputs = op.n_outputs(); - assert!( - n_outputs <= u8::MAX as usize + 1, - "operation {:?} has too many outputs for GlobalValKey: {}", - op, - n_outputs - ); - - let op_id = self.ops.len(); - let global_inputs: Vec> = inputs - .iter() - .map(|input| self.resolve_input_key(input)) - .collect(); - - let global_op_key = Arc::new(GlobalOpKey::new(op.clone(), global_inputs, mode.clone())); - - let mut output_ids = Vec::with_capacity(n_outputs); - for slot in 0..n_outputs { - let val_id = self.vals.len(); - let key = GlobalValKey::Derived { - op: Arc::clone(&global_op_key), - output_slot: slot as u8, - }; - self.vals.push(ValNode { - key: key.clone(), - producer: Some((op_id, slot)), - }); - self.local_keys.push(key); - output_ids.push(val_id); - } - - self.ops.push(OpNode { - op, - inputs, - outputs: output_ids.clone(), - mode, - }); - - output_ids - } - - /// Declares the fragment outputs. - pub fn set_outputs(&mut self, outputs: Vec) { - for &output in &outputs { - assert!( - output < self.vals.len(), - "unknown local output value id {}", - output - ); - } - self.outputs = outputs; - } - - /// Registers a parent fragment for external reference resolution. - pub fn add_parent(&mut self, parent: Arc>) { - self.parents.push(parent); - } - - /// Returns the global key for a local value id. - pub fn global_key(&self, local_id: LocalValId) -> &GlobalValKey { - assert!( - local_id < self.local_keys.len(), - "unknown local value id {}", - local_id - ); - &self.local_keys[local_id] - } - - /// Consumes the builder and produces a fragment. - pub fn build(self) -> Fragment { - Fragment { - vals: self.vals, - ops: self.ops, - inputs: self.inputs, - outputs: self.outputs, - parents: self.parents, - } - } - - fn resolve_input_key(&self, input: &ValRef) -> GlobalValKey { - match input { - ValRef::Local(local_id) => self.global_key(*local_id).clone(), - ValRef::External(key) => key.clone(), - } - } -} - -impl Default for FragmentBuilder { - fn default() -> Self { - Self::new() - } -} diff --git a/src/graph.rs b/src/graph.rs new file mode 100644 index 0000000..4781c4c --- /dev/null +++ b/src/graph.rs @@ -0,0 +1,241 @@ +use std::sync::Arc; + +use crate::traits::GraphOperation; +use crate::types::{ + LocalOperationId, LocalValueId, OperationKey, OperationRole, ValueKey, ValueRef, +}; + +/// A value node in a graph. +pub struct ValueNode { + /// Cross-graph structural identity. + pub key: ValueKey, + /// `None` for graph inputs; `Some((op_id, output_slot))` for produced values. + pub producer: Option<(LocalOperationId, usize)>, +} + +/// An operation node in a graph. +pub struct OperationNode { + pub operation: Op, + pub inputs: Vec>, + pub outputs: Vec, + pub role: OperationRole, +} + +/// The unit of graph construction. +/// +/// # Examples +/// +/// ``` +/// use computegraph::graph::GraphBuilder; +/// use computegraph::{GraphOperation, OperationRole, ValueRef}; +/// +/// #[derive(Clone, Debug, Hash, PartialEq, Eq)] +/// enum IdentityOp { +/// Identity, +/// } +/// +/// impl GraphOperation for IdentityOp { +/// type Operand = f64; +/// type Context = (); +/// type InputKey = &'static str; +/// +/// fn input_count(&self) -> usize { 1 } +/// fn output_count(&self) -> usize { 1 } +/// } +/// +/// let mut builder = GraphBuilder::::new(); +/// let x = builder.add_input("x"); +/// let y = builder.add_operation( +/// IdentityOp::Identity, +/// vec![ValueRef::Local(x)], +/// OperationRole::Primary, +/// ); +/// builder.set_outputs(y.clone()); +/// let graph = builder.build(); +/// +/// assert_eq!(graph.inputs(), &[x]); +/// assert_eq!(graph.outputs(), y.as_slice()); +/// ``` +pub struct Graph { + pub(crate) values: Vec>, + pub(crate) operations: Vec>, + pub(crate) inputs: Vec, + pub(crate) outputs: Vec, + pub(crate) parents: Vec>>, +} + +impl Graph { + /// Returns all value nodes in this graph. + pub fn values(&self) -> &[ValueNode] { + &self.values + } + + /// Returns all operation nodes in this graph. + pub fn operations(&self) -> &[OperationNode] { + &self.operations + } + + /// Returns the graph-local input value ids. + pub fn inputs(&self) -> &[LocalValueId] { + &self.inputs + } + + /// Returns the graph-local output value ids. + pub fn outputs(&self) -> &[LocalValueId] { + &self.outputs + } + + /// Returns the parent graphs referenced by this graph. + pub fn parents(&self) -> &[Arc>] { + &self.parents + } +} + +/// Builder for constructing graphs incrementally. +pub struct GraphBuilder { + values: Vec>, + operations: Vec>, + inputs: Vec, + outputs: Vec, + parents: Vec>>, + local_keys: Vec>, +} + +impl GraphBuilder { + /// Creates an empty graph builder. + pub fn new() -> Self { + Self { + values: Vec::new(), + operations: Vec::new(), + inputs: Vec::new(), + outputs: Vec::new(), + parents: Vec::new(), + local_keys: Vec::new(), + } + } + + /// Adds a graph input and returns its local id. + pub fn add_input(&mut self, key: Op::InputKey) -> LocalValueId { + let val_id = self.values.len(); + let global_key = ValueKey::Input(key); + self.values.push(ValueNode { + key: global_key.clone(), + producer: None, + }); + self.local_keys.push(global_key); + self.inputs.push(val_id); + val_id + } + + /// Adds an operation node and returns the local ids for each output. + pub fn add_operation( + &mut self, + operation: Op, + inputs: Vec>, + role: OperationRole, + ) -> Vec { + assert_eq!( + inputs.len(), + operation.input_count(), + "operation {:?} expected {} inputs, got {}", + operation, + operation.input_count(), + inputs.len() + ); + + let output_count = operation.output_count(); + assert!( + output_count <= u8::MAX as usize + 1, + "operation {:?} has too many outputs for ValueKey: {}", + operation, + output_count + ); + + let op_id = self.operations.len(); + let global_inputs: Vec> = inputs + .iter() + .map(|input| self.resolve_input_key(input)) + .collect(); + + let global_op_key = Arc::new(OperationKey::new( + operation.clone(), + global_inputs, + role.clone(), + )); + + let mut output_ids = Vec::with_capacity(output_count); + for slot in 0..output_count { + let val_id = self.values.len(); + let key = ValueKey::Derived { + operation: Arc::clone(&global_op_key), + output_slot: slot as u8, + }; + self.values.push(ValueNode { + key: key.clone(), + producer: Some((op_id, slot)), + }); + self.local_keys.push(key); + output_ids.push(val_id); + } + + self.operations.push(OperationNode { + operation, + inputs, + outputs: output_ids.clone(), + role, + }); + + output_ids + } + + /// Declares the graph outputs. + pub fn set_outputs(&mut self, outputs: Vec) { + for &output in &outputs { + assert!( + output < self.values.len(), + "unknown local output value id {}", + output + ); + } + self.outputs = outputs; + } + + /// Registers a parent graph for external reference resolution. + pub fn add_parent(&mut self, parent: Arc>) { + self.parents.push(parent); + } + + /// Returns the global key for a local value id. + pub fn global_key(&self, local_id: LocalValueId) -> &ValueKey { + assert!( + local_id < self.local_keys.len(), + "unknown local value id {}", + local_id + ); + &self.local_keys[local_id] + } + + /// Consumes the builder and produces a graph. + pub fn build(self) -> Graph { + Graph { + values: self.values, + operations: self.operations, + inputs: self.inputs, + outputs: self.outputs, + parents: self.parents, + } + } + + fn resolve_input_key(&self, input: &ValueRef) -> ValueKey { + match input { + ValueRef::Local(local_id) => self.global_key(*local_id).clone(), + ValueRef::External(key) => key.clone(), + } + } +} + +impl Default for GraphBuilder { + fn default() -> Self { + Self::new() + } +} diff --git a/src/interner.rs b/src/interner.rs index 3b5cba8..0ec3161 100644 --- a/src/interner.rs +++ b/src/interner.rs @@ -1,20 +1,20 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; -use crate::traits::GraphOp; -use crate::types::GlobalValKey; +use crate::traits::GraphOperation; +use crate::types::ValueKey; -/// Interned identity for O(1) equality comparison of [`GlobalValKey`]. +/// Interned identity for O(1) equality comparison of [`ValueKey`]. #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] -pub struct ValKeyId(u32); +pub struct ValueKeyId(u32); -/// Maps [`GlobalValKey`] to [`ValKeyId`] for fast equality and deduplication. -pub struct KeyInterner { - map: HashMap, ValKeyId>, - keys: Vec>, +/// Maps [`ValueKey`] to [`ValueKeyId`] for fast equality and deduplication. +pub struct ValueKeyInterner { + map: HashMap, ValueKeyId>, + keys: Vec>, } -impl KeyInterner { +impl ValueKeyInterner { /// Creates an empty interner. pub fn new() -> Self { Self { @@ -24,7 +24,7 @@ impl KeyInterner { } /// Interns a key, returning its unique id. - pub fn intern(&mut self, key: GlobalValKey) -> ValKeyId { + pub fn intern(&mut self, key: ValueKey) -> ValueKeyId { match self.map.entry(key.clone()) { Entry::Occupied(entry) => *entry.get(), Entry::Vacant(entry) => { @@ -33,7 +33,7 @@ impl KeyInterner { "too many interned value keys: {}", self.keys.len() ); - let id = ValKeyId(self.keys.len() as u32); + let id = ValueKeyId(self.keys.len() as u32); self.keys.push(key); entry.insert(id); id @@ -42,12 +42,12 @@ impl KeyInterner { } /// Looks up the id for a key without interning it. - pub fn get(&self, key: &GlobalValKey) -> Option { + pub fn get(&self, key: &ValueKey) -> Option { self.map.get(key).copied() } /// Retrieves the full key from an id. - pub fn resolve(&self, id: ValKeyId) -> &GlobalValKey { + pub fn resolve(&self, id: ValueKeyId) -> &ValueKey { let index = id.0 as usize; assert!( index < self.keys.len(), @@ -58,7 +58,7 @@ impl KeyInterner { } } -impl Default for KeyInterner { +impl Default for ValueKeyInterner { fn default() -> Self { Self::new() } diff --git a/src/lib.rs b/src/lib.rs index a53fd06..103c191 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,12 +2,12 @@ pub mod compile; mod eval; -pub mod fragment; +pub mod graph; pub mod interner; pub mod materialize; pub mod resolve; pub mod traits; pub mod types; -pub use traits::{EvalGraphOp, GraphOp, OpEmitter}; -pub use types::{GlobalOpKey, GlobalValKey, LocalOpId, LocalValId, OpMode, ValRef}; +pub use traits::{EvaluableGraphOperation, GraphOperation}; +pub use types::{LocalOperationId, LocalValueId, OperationKey, OperationRole, ValueKey, ValueRef}; diff --git a/src/materialize.rs b/src/materialize.rs index ec85e78..2beb5f4 100644 --- a/src/materialize.rs +++ b/src/materialize.rs @@ -1,80 +1,80 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::resolve::{ResolvedView, ValDef}; -use crate::traits::GraphOp; -use crate::types::{GlobalOpKey, GlobalValKey, OpMode}; +use crate::resolve::{ResolvedView, ValueDef}; +use crate::traits::GraphOperation; +use crate::types::{OperationKey, OperationRole, ValueKey}; /// A value in the materialized graph. -pub struct MaterializedVal { - pub key: GlobalValKey, +pub struct MaterializedValue { + pub key: ValueKey, /// `None` for inputs; `Some((op_index, output_slot))` for produced values. pub producer: Option<(usize, usize)>, } /// An operation in the materialized graph. -pub struct MaterializedOp { - pub op: Op, +pub struct MaterializedOperation { + pub operation: Op, pub inputs: Vec, pub outputs: Vec, - pub mode: OpMode, + pub role: OperationRole, } /// Fully flattened, deduplicated graph ready for compilation. -pub struct MaterializedGraph { - pub vals: Vec>, - pub ops: Vec>, - pub inputs: Vec>, - pub outputs: Vec>, +pub struct MaterializedGraph { + pub values: Vec>, + pub operations: Vec>, + pub inputs: Vec>, + pub outputs: Vec>, } -struct Materializer<'a, Op: GraphOp> { +struct Materializer<'a, Op: GraphOperation> { view: &'a ResolvedView, - val_map: HashMap, usize>, - op_map: HashMap>, usize>, - vals: Vec>, - ops: Vec>, - input_keys: Vec>, + val_map: HashMap, usize>, + op_map: HashMap>, usize>, + values: Vec>, + operations: Vec>, + input_keys: Vec>, } -impl<'a, Op: GraphOp> Materializer<'a, Op> { +impl<'a, Op: GraphOperation> Materializer<'a, Op> { fn new(view: &'a ResolvedView) -> Self { Self { view, val_map: HashMap::new(), op_map: HashMap::new(), - vals: Vec::new(), - ops: Vec::new(), + values: Vec::new(), + operations: Vec::new(), input_keys: Vec::new(), } } - fn visit(&mut self, key: &GlobalValKey) -> usize { + fn visit(&mut self, key: &ValueKey) -> usize { if let Some(&index) = self.val_map.get(key) { return index; } - let resolved = self.view.resolve_val(key); + let resolved = self.view.resolve_value(key); assert!( resolved.is_some(), "key not found in resolved view: {:?}", key ); match resolved { - Some(ValDef::Input { .. }) => self.materialize_input(key), - Some(ValDef::Produced { - op, + Some(ValueDef::Input { .. }) => self.materialize_input(key), + Some(ValueDef::Produced { + operation, input_keys, - mode, + role, output_slot, - }) => self.materialize_produced(op, input_keys, mode, output_slot), + }) => self.materialize_produced(operation, input_keys, role, output_slot), None => unreachable!("asserted above"), } } - fn materialize_input(&mut self, key: &GlobalValKey) -> usize { - let index = self.vals.len(); - self.vals.push(MaterializedVal { + fn materialize_input(&mut self, key: &ValueKey) -> usize { + let index = self.values.len(); + self.values.push(MaterializedValue { key: key.clone(), producer: None, }); @@ -85,27 +85,27 @@ impl<'a, Op: GraphOp> Materializer<'a, Op> { fn materialize_produced( &mut self, - op: Op, - input_keys: Vec>, - mode: OpMode, + operation: Op, + input_keys: Vec>, + role: OperationRole, output_slot: usize, ) -> usize { - let op_key = Arc::new(GlobalOpKey::new( - op.clone(), + let op_key = Arc::new(OperationKey::new( + operation.clone(), input_keys.clone(), - mode.clone(), + role.clone(), )); if self.op_map.contains_key(&op_key) { - let output_key = GlobalValKey::Derived { - op: op_key, + let output_key = ValueKey::Derived { + operation: op_key, output_slot: output_slot as u8, }; let val_index = self.val_map.get(&output_key).copied(); assert!( val_index.is_some(), "materialized op {:?} is missing output slot {}", - op, + operation, output_slot ); return match val_index { @@ -115,37 +115,37 @@ impl<'a, Op: GraphOp> Materializer<'a, Op> { } let materialized_inputs = input_keys.iter().map(|input| self.visit(input)).collect(); - let op_index = self.ops.len(); + let op_index = self.operations.len(); self.op_map.insert(Arc::clone(&op_key), op_index); - self.ops.push(MaterializedOp { - op: op.clone(), + self.operations.push(MaterializedOperation { + operation: operation.clone(), inputs: materialized_inputs, - outputs: Vec::with_capacity(op.n_outputs()), - mode, + outputs: Vec::with_capacity(operation.output_count()), + role, }); - for slot in 0..op.n_outputs() { - let output_key = GlobalValKey::Derived { - op: Arc::clone(&op_key), + for slot in 0..operation.output_count() { + let output_key = ValueKey::Derived { + operation: Arc::clone(&op_key), output_slot: slot as u8, }; - let val_index = self.vals.len(); - self.vals.push(MaterializedVal { + let val_index = self.values.len(); + self.values.push(MaterializedValue { key: output_key.clone(), producer: Some((op_index, slot)), }); self.val_map.insert(output_key, val_index); - self.ops[op_index].outputs.push(val_index); + self.operations[op_index].outputs.push(val_index); } - self.ops[op_index].outputs[output_slot] + self.operations[op_index].outputs[output_slot] } } -/// Flattens resolved fragments into a single materialized graph. -pub fn materialize_merge( +/// Flattens resolved graphs into a single materialized graph. +pub fn materialize_merge( view: &ResolvedView, - outputs: &[GlobalValKey], + outputs: &[ValueKey], ) -> MaterializedGraph { let mut materializer = Materializer::new(view); @@ -154,8 +154,8 @@ pub fn materialize_merge( } MaterializedGraph { - vals: materializer.vals, - ops: materializer.ops, + values: materializer.values, + operations: materializer.operations, inputs: materializer.input_keys, outputs: outputs.to_vec(), } diff --git a/src/resolve.rs b/src/resolve.rs index 15d0e24..960b4c9 100644 --- a/src/resolve.rs +++ b/src/resolve.rs @@ -1,60 +1,60 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use crate::fragment::Fragment; -use crate::traits::GraphOp; -use crate::types::{GlobalValKey, OpMode, ValRef}; +use crate::graph::Graph; +use crate::traits::GraphOperation; +use crate::types::{OperationRole, ValueKey, ValueRef}; /// Definition of a value as seen through the resolver. #[derive(Clone, Debug, PartialEq)] -pub enum ValDef { +pub enum ValueDef { Input { key: Op::InputKey, }, Produced { - op: Op, + operation: Op, /// Inputs resolved to global keys. - input_keys: Vec>, - mode: OpMode, + input_keys: Vec>, + role: OperationRole, output_slot: usize, }, } -/// Trait for resolving [`GlobalValKey`] to its definition. -pub trait Resolver { - fn resolve_val(&self, key: &GlobalValKey) -> Option>; +/// Trait for resolving [`ValueKey`] to its definition. +pub trait Resolver { + fn resolve_value(&self, key: &ValueKey) -> Option>; } -/// Logical traversal view over one or more fragments. -pub struct ResolvedView { - pub roots: Vec>>, +/// Logical traversal view over one or more graphs. +pub struct ResolvedView { + pub roots: Vec>>, resolver: Box>, } -impl ResolvedView { +impl ResolvedView { /// Resolves a global value key to its logical definition. - pub fn resolve_val(&self, key: &GlobalValKey) -> Option> { - self.resolver.resolve_val(key) + pub fn resolve_value(&self, key: &ValueKey) -> Option> { + self.resolver.resolve_value(key) } } -struct HashMapResolver { - map: HashMap, ValDef>, +struct HashMapResolver { + map: HashMap, ValueDef>, } -impl Resolver for HashMapResolver { - fn resolve_val(&self, key: &GlobalValKey) -> Option> { +impl Resolver for HashMapResolver { + fn resolve_value(&self, key: &ValueKey) -> Option> { self.map.get(key).cloned() } } -/// Builds a logical lookup view over fragments and their parent chains. -pub fn resolve(roots: Vec>>) -> ResolvedView { +/// Builds a logical lookup view over graphs and their parent chains. +pub fn resolve(roots: Vec>>) -> ResolvedView { let mut map = HashMap::new(); let mut visited = HashSet::new(); for root in &roots { - walk_fragment(root, &mut map, &mut visited); + walk_graph(root, &mut map, &mut visited); } ResolvedView { @@ -63,21 +63,21 @@ pub fn resolve(roots: Vec>>) -> ResolvedView { } } -fn walk_fragment( - fragment: &Fragment, - map: &mut HashMap, ValDef>, - visited: &mut HashSet<*const Fragment>, +fn walk_graph( + graph: &Graph, + map: &mut HashMap, ValueDef>, + visited: &mut HashSet<*const Graph>, ) { - let fragment_ptr: *const Fragment = fragment; - if !visited.insert(fragment_ptr) { + let graph_ptr: *const Graph = graph; + if !visited.insert(graph_ptr) { return; } - for parent in fragment.parents() { - walk_fragment(parent, map, visited); + for parent in graph.parents() { + walk_graph(parent, map, visited); } - for val in fragment.vals() { + for val in graph.values() { if map.contains_key(&val.key) { continue; } @@ -85,44 +85,44 @@ fn walk_fragment( match val.producer { None => { let input_key = match &val.key { - GlobalValKey::Input(key) => key.clone(), + ValueKey::Input(key) => key.clone(), _ => panic!( - "fragment input value must use GlobalValKey::Input, got {:?}", + "graph input value must use ValueKey::Input, got {:?}", val.key ), }; - map.insert(val.key.clone(), ValDef::Input { key: input_key }); + map.insert(val.key.clone(), ValueDef::Input { key: input_key }); } Some((op_id, output_slot)) => { assert!( - op_id < fragment.ops().len(), + op_id < graph.operations().len(), "value references unknown producer op id {}", op_id ); - let op_node = &fragment.ops()[op_id]; - let input_keys = op_node + let operation_node = &graph.operations()[op_id]; + let input_keys = operation_node .inputs .iter() .map(|input| match input { - ValRef::Local(local_id) => { + ValueRef::Local(local_id) => { assert!( - *local_id < fragment.vals().len(), + *local_id < graph.values().len(), "operation {:?} references unknown local value id {}", - op_node.op, + operation_node.operation, local_id ); - fragment.vals()[*local_id].key.clone() + graph.values()[*local_id].key.clone() } - ValRef::External(key) => key.clone(), + ValueRef::External(key) => key.clone(), }) .collect(); map.insert( val.key.clone(), - ValDef::Produced { - op: op_node.op.clone(), + ValueDef::Produced { + operation: operation_node.operation.clone(), input_keys, - mode: op_node.mode.clone(), + role: operation_node.role.clone(), output_slot, }, ); diff --git a/src/traits.rs b/src/traits.rs index 4e17342..7e849c0 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,113 +1,75 @@ use std::hash::Hash; -use crate::fragment::FragmentBuilder; -use crate::types::{LocalValId, OpMode, ValRef}; - /// Operation node trait. `computegraph` is fully generic over this abstraction. /// -/// `GraphOp` captures the metadata of an operation (input/output counts, -/// associated types) but does **not** include evaluation. See [`EvalGraphOp`] +/// `GraphOperation` captures the metadata of an operation (input/output counts, +/// associated types) but does **not** include evaluation. See [`EvaluableGraphOperation`] /// for the evaluation extension. /// /// # Examples /// -/// ```ignore -/// use computegraph::GraphOp; +/// ``` +/// use computegraph::GraphOperation; /// /// #[derive(Clone, Debug, Hash, PartialEq, Eq)] /// enum AddOp { /// Add, /// } /// -/// impl GraphOp for AddOp { +/// impl GraphOperation for AddOp { /// type Operand = f64; /// type Context = (); /// type InputKey = &'static str; /// -/// fn n_inputs(&self) -> usize { 2 } -/// fn n_outputs(&self) -> usize { 1 } +/// fn input_count(&self) -> usize { 2 } +/// fn output_count(&self) -> usize { 1 } /// } +/// +/// assert_eq!(AddOp::Add.input_count(), 2); /// ``` -pub trait GraphOp: Clone + std::fmt::Debug + Hash + Eq + Send + Sync + 'static { +pub trait GraphOperation: Clone + std::fmt::Debug + Hash + Eq + Send + Sync + 'static { type Operand: Clone + Send + Sync + 'static; type Context; type InputKey: Clone + std::fmt::Debug + Hash + Eq + Send + Sync + 'static; /// Returns the number of inputs consumed by this operation. - fn n_inputs(&self) -> usize; + fn input_count(&self) -> usize; /// Returns the number of outputs produced by this operation. - fn n_outputs(&self) -> usize; + fn output_count(&self) -> usize; } -/// Minimal trait for emitting operations into a computation context. -/// -/// AD transpose rules use only this interface, enabling both graph-building -/// (`FragmentBuilder`) and eager execution through the same code. +/// Extension trait that adds evaluation capability to a [`GraphOperation`]. /// /// # Examples /// -/// ```ignore -/// use computegraph::{FragmentBuilder, GraphOp, OpEmitter, OpMode, ValRef}; -/// -/// #[derive(Clone, Debug, Hash, PartialEq, Eq)] -/// enum UnaryOp { -/// Identity, -/// } -/// -/// impl GraphOp for UnaryOp { -/// type Operand = f64; -/// type Context = (); -/// type InputKey = &'static str; -/// -/// fn n_inputs(&self) -> usize { 1 } -/// fn n_outputs(&self) -> usize { 1 } -/// } -/// -/// let mut builder = FragmentBuilder::::new(); -/// let x = builder.add_input("x"); -/// let ys = builder.add_op(UnaryOp::Identity, vec![ValRef::Local(x)], OpMode::Primal); -/// assert_eq!(ys.len(), 1); /// ``` -pub trait OpEmitter { - /// Emits an operation with the given inputs and mode, returning output ids. - fn add_op(&mut self, op: Op, inputs: Vec>, mode: OpMode) -> Vec; -} - -impl OpEmitter for FragmentBuilder { - fn add_op(&mut self, op: Op, inputs: Vec>, mode: OpMode) -> Vec { - FragmentBuilder::add_op(self, op, inputs, mode) - } -} - -/// Extension trait that adds evaluation capability to a [`GraphOp`]. -/// -/// # Examples -/// -/// ```ignore -/// use computegraph::{GraphOp, EvalGraphOp}; +/// use computegraph::{EvaluableGraphOperation, GraphOperation}; /// /// #[derive(Clone, Debug, Hash, PartialEq, Eq)] /// enum AddOp { /// Add, /// } /// -/// impl GraphOp for AddOp { +/// impl GraphOperation for AddOp { /// type Operand = f64; /// type Context = (); /// type InputKey = &'static str; /// -/// fn n_inputs(&self) -> usize { 2 } -/// fn n_outputs(&self) -> usize { 1 } +/// fn input_count(&self) -> usize { 2 } +/// fn output_count(&self) -> usize { 1 } /// } /// -/// impl EvalGraphOp for AddOp { +/// impl EvaluableGraphOperation for AddOp { /// fn eval(&self, _ctx: &mut Self::Context, inputs: &[&Self::Operand]) -> Vec { /// vec![inputs[0] + inputs[1]] /// } /// } +/// +/// let result = AddOp::Add.eval(&mut (), &[&3.0, &4.0]); +/// assert_eq!(result, vec![7.0]); /// ``` -pub trait EvalGraphOp: GraphOp { +pub trait EvaluableGraphOperation: GraphOperation { /// Evaluates the operation given concrete input operands. fn eval(&self, ctx: &mut Self::Context, inputs: &[&Self::Operand]) -> Vec; } diff --git a/src/types.rs b/src/types.rs index 2b7462c..91d8d1d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -2,49 +2,49 @@ use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::traits::GraphOp; +use crate::traits::GraphOperation; -/// Fragment-local value identifier. -pub type LocalValId = usize; +/// Graph-local value identifier. +pub type LocalValueId = usize; -/// Fragment-local operation identifier. -pub type LocalOpId = usize; +/// Graph-local operation identifier. +pub type LocalOperationId = usize; -/// Distinguishes primal nodes from linear (AD-generated) nodes. +/// Describes the role an operation plays in a graph. #[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum OpMode { - Primal, - Linear { active_mask: Vec }, +pub enum OperationRole { + Primary, + Linearized { active_mask: Vec }, } -/// Reference to a value: either local to the current fragment or external. +/// Reference to a value: either local to the current graph or external. #[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum ValRef { - Local(LocalValId), - External(GlobalValKey), +pub enum ValueRef { + Local(LocalValueId), + External(ValueKey), } -/// Cross-fragment structural identity for a value. +/// Cross-graph structural identity for a value. #[derive(Clone, Debug)] -pub enum GlobalValKey { +pub enum ValueKey { Input(Op::InputKey), Derived { /// Shared structural identity of the operation that produced this value. - op: Arc>, + operation: Arc>, output_slot: u8, }, } -/// Cross-fragment structural identity for an operation. +/// Cross-graph structural identity for an operation. /// -/// `GlobalOpKey` caches a structural fingerprint so maps keyed by recursively +/// `OperationKey` caches a structural fingerprint so maps keyed by recursively /// derived values can avoid repeatedly re-hashing the whole input tree. Equality /// still checks the full structure after the fingerprint prefilter. #[derive(Clone, Debug)] -pub struct GlobalOpKey { - primitive: Op, - inputs: Vec>, - mode: OpMode, +pub struct OperationKey { + operation: Op, + inputs: Vec>, + role: OperationRole, /// Cached hash prefilter for recursively structural keys. /// /// This is not an identity proof: equality still compares the full @@ -53,41 +53,42 @@ pub struct GlobalOpKey { fingerprint: u64, } -impl GlobalOpKey { +impl OperationKey { /// Builds an operation key and precomputes its structural fingerprint. /// /// # Examples /// - /// ```ignore - /// use computegraph::{GlobalOpKey, GlobalValKey, GraphOp, OpMode}; + /// ``` + /// use computegraph::{GraphOperation, OperationKey, OperationRole, ValueKey}; /// /// #[derive(Clone, Debug, Hash, PartialEq, Eq)] /// enum Op { /// Add, /// } /// - /// impl GraphOp for Op { + /// impl GraphOperation for Op { /// type Operand = f64; /// type Context = (); /// type InputKey = &'static str; /// - /// fn n_inputs(&self) -> usize { 2 } - /// fn n_outputs(&self) -> usize { 1 } + /// fn input_count(&self) -> usize { 2 } + /// fn output_count(&self) -> usize { 1 } /// } /// - /// let key = GlobalOpKey::new( + /// let key = OperationKey::new( /// Op::Add, - /// vec![GlobalValKey::Input("x"), GlobalValKey::Input("y")], - /// OpMode::Primal, + /// vec![ValueKey::Input("x"), ValueKey::Input("y")], + /// OperationRole::Primary, /// ); /// assert_eq!(key.inputs().len(), 2); + /// assert_eq!(key.role(), &OperationRole::Primary); /// ``` - pub fn new(primitive: Op, inputs: Vec>, mode: OpMode) -> Self { - let fingerprint = fingerprint_op(&primitive, &inputs, &mode); + pub fn new(operation: Op, inputs: Vec>, role: OperationRole) -> Self { + let fingerprint = fingerprint_operation(&operation, &inputs, &role); Self { - primitive, + operation, inputs, - mode, + role, fingerprint, } } @@ -97,33 +98,33 @@ impl GlobalOpKey { self.fingerprint } - /// Returns the operation primitive. - pub fn primitive(&self) -> &Op { - &self.primitive + /// Returns the operation. + pub fn operation(&self) -> &Op { + &self.operation } /// Returns the structural input keys. - pub fn inputs(&self) -> &[GlobalValKey] { + pub fn inputs(&self) -> &[ValueKey] { &self.inputs } - /// Returns whether this operation belongs to the primal or linear graph. - pub fn mode(&self) -> &OpMode { - &self.mode + /// Returns the role of this operation in the graph. + pub fn role(&self) -> &OperationRole { + &self.role } } -impl PartialEq for GlobalValKey { +impl PartialEq for ValueKey { fn eq(&self, other: &Self) -> bool { match (self, other) { (Self::Input(lhs), Self::Input(rhs)) => lhs == rhs, ( Self::Derived { - op: lhs_op, + operation: lhs_op, output_slot: lhs_slot, }, Self::Derived { - op: rhs_op, + operation: rhs_op, output_slot: rhs_slot, }, ) => { @@ -135,45 +136,52 @@ impl PartialEq for GlobalValKey { } } -impl Eq for GlobalValKey {} +impl Eq for ValueKey {} -impl Hash for GlobalValKey { +impl Hash for ValueKey { fn hash(&self, state: &mut H) { match self { Self::Input(key) => { 0u8.hash(state); key.hash(state); } - Self::Derived { op, output_slot } => { + Self::Derived { + operation, + output_slot, + } => { 1u8.hash(state); - op.fingerprint.hash(state); + operation.fingerprint.hash(state); output_slot.hash(state); } } } } -impl PartialEq for GlobalOpKey { +impl PartialEq for OperationKey { fn eq(&self, other: &Self) -> bool { self.fingerprint == other.fingerprint - && self.primitive == other.primitive - && self.mode == other.mode + && self.operation == other.operation + && self.role == other.role && self.inputs == other.inputs } } -impl Eq for GlobalOpKey {} +impl Eq for OperationKey {} -impl Hash for GlobalOpKey { +impl Hash for OperationKey { fn hash(&self, state: &mut H) { self.fingerprint.hash(state); } } -fn fingerprint_op(primitive: &Op, inputs: &[GlobalValKey], mode: &OpMode) -> u64 { +fn fingerprint_operation( + operation: &Op, + inputs: &[ValueKey], + role: &OperationRole, +) -> u64 { let mut hasher = DefaultHasher::new(); - primitive.hash(&mut hasher); - mode.hash(&mut hasher); + operation.hash(&mut hasher); + role.hash(&mut hasher); inputs.len().hash(&mut hasher); for input in inputs { fingerprint_val(input).hash(&mut hasher); @@ -181,7 +189,7 @@ fn fingerprint_op(primitive: &Op, inputs: &[GlobalValKey], mode hasher.finish() } -fn fingerprint_val(key: &GlobalValKey) -> u64 { +fn fingerprint_val(key: &ValueKey) -> u64 { let mut hasher = DefaultHasher::new(); key.hash(&mut hasher); hasher.finish() diff --git a/tests/common/mod.rs b/tests/common/mod.rs index d6b230c..907fba1 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,4 +1,4 @@ -use computegraph::{EvalGraphOp, GraphOp}; +use computegraph::{EvaluableGraphOperation, GraphOperation}; /// Scalar operations for testing. #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -10,19 +10,19 @@ pub enum ScalarOp { Dup, } -impl GraphOp for ScalarOp { +impl GraphOperation for ScalarOp { type Operand = f64; type Context = (); type InputKey = String; - fn n_inputs(&self) -> usize { + fn input_count(&self) -> usize { match self { ScalarOp::Add | ScalarOp::Mul => 2, ScalarOp::Exp | ScalarOp::Neg | ScalarOp::Dup => 1, } } - fn n_outputs(&self) -> usize { + fn output_count(&self) -> usize { match self { ScalarOp::Dup => 2, _ => 1, @@ -30,7 +30,7 @@ impl GraphOp for ScalarOp { } } -impl EvalGraphOp for ScalarOp { +impl EvaluableGraphOperation for ScalarOp { fn eval(&self, _ctx: &mut (), inputs: &[&f64]) -> Vec { match self { ScalarOp::Add => vec![inputs[0] + inputs[1]], diff --git a/tests/scalar_tests.rs b/tests/scalar_tests.rs index 7ec295b..9fe54a4 100644 --- a/tests/scalar_tests.rs +++ b/tests/scalar_tests.rs @@ -6,19 +6,21 @@ use std::sync::Arc; use common::ScalarOp; use computegraph::compile::compile; -use computegraph::fragment::FragmentBuilder; -use computegraph::interner::KeyInterner; +use computegraph::graph::GraphBuilder; +use computegraph::interner::ValueKeyInterner; use computegraph::materialize::materialize_merge; -use computegraph::resolve::{resolve, ValDef}; -use computegraph::{EvalGraphOp, GlobalOpKey, GlobalValKey, GraphOp, OpMode, ValRef}; +use computegraph::resolve::{resolve, ValueDef}; +use computegraph::{ + EvaluableGraphOperation, GraphOperation, OperationKey, OperationRole, ValueKey, ValueRef, +}; // === ScalarOp smoke tests === #[test] fn scalar_op_eval_add() { let op = ScalarOp::Add; - assert_eq!(op.n_inputs(), 2); - assert_eq!(op.n_outputs(), 1); + assert_eq!(op.input_count(), 2); + assert_eq!(op.output_count(), 1); let result = op.eval(&mut (), &[&3.0, &4.0]); assert_eq!(result, vec![7.0]); } @@ -33,25 +35,25 @@ fn scalar_op_eval_exp() { #[test] fn scalar_op_eval_dup() { let op = ScalarOp::Dup; - assert_eq!(op.n_outputs(), 2); + assert_eq!(op.output_count(), 2); let result = op.eval(&mut (), &[&5.0]); assert_eq!(result, vec![5.0, 5.0]); } -// === KeyInterner tests === +// === ValueKeyInterner tests === #[test] fn interner_intern_and_resolve() { - let mut interner = KeyInterner::::new(); - let key = GlobalValKey::Input("x".to_string()); + let mut interner = ValueKeyInterner::::new(); + let key = ValueKey::Input("x".to_string()); let id = interner.intern(key.clone()); assert_eq!(interner.resolve(id), &key); } #[test] fn interner_deduplicates() { - let mut interner = KeyInterner::::new(); - let key = GlobalValKey::Input("x".to_string()); + let mut interner = ValueKeyInterner::::new(); + let key = ValueKey::Input("x".to_string()); let id1 = interner.intern(key.clone()); let id2 = interner.intern(key); assert_eq!(id1, id2); @@ -59,30 +61,30 @@ fn interner_deduplicates() { #[test] fn interner_distinct_keys_get_distinct_ids() { - let mut interner = KeyInterner::::new(); - let id_x = interner.intern(GlobalValKey::Input("x".to_string())); - let id_y = interner.intern(GlobalValKey::Input("y".to_string())); + let mut interner = ValueKeyInterner::::new(); + let id_x = interner.intern(ValueKey::Input("x".to_string())); + let id_y = interner.intern(ValueKey::Input("y".to_string())); assert_ne!(id_x, id_y); } #[test] fn interner_get_returns_none_for_unknown() { - let interner = KeyInterner::::new(); - let key = GlobalValKey::Input("x".to_string()); + let interner = ValueKeyInterner::::new(); + let key = ValueKey::Input("x".to_string()); assert_eq!(interner.get(&key), None); } #[test] fn interner_derived_key() { - let mut interner = KeyInterner::::new(); - let key = GlobalValKey::::Derived { - op: Arc::new(GlobalOpKey::new( + let mut interner = ValueKeyInterner::::new(); + let key = ValueKey::::Derived { + operation: Arc::new(OperationKey::new( ScalarOp::Add, vec![ - GlobalValKey::Input("x".to_string()), - GlobalValKey::Input("y".to_string()), + ValueKey::Input("x".to_string()), + ValueKey::Input("y".to_string()), ], - OpMode::Primal, + OperationRole::Primary, )), output_slot: 0, }; @@ -94,19 +96,23 @@ fn interner_derived_key() { #[test] fn derived_keys_with_distinct_op_arcs_are_structurally_equal() { let inputs = vec![ - GlobalValKey::Input("x".to_string()), - GlobalValKey::Input("y".to_string()), + ValueKey::Input("x".to_string()), + ValueKey::Input("y".to_string()), ]; - let lhs = GlobalValKey::::Derived { - op: Arc::new(GlobalOpKey::new( + let lhs = ValueKey::::Derived { + operation: Arc::new(OperationKey::new( ScalarOp::Add, inputs.clone(), - OpMode::Primal, + OperationRole::Primary, )), output_slot: 0, }; - let rhs = GlobalValKey::::Derived { - op: Arc::new(GlobalOpKey::new(ScalarOp::Add, inputs, OpMode::Primal)), + let rhs = ValueKey::::Derived { + operation: Arc::new(OperationKey::new( + ScalarOp::Add, + inputs, + OperationRole::Primary, + )), output_slot: 0, }; @@ -114,103 +120,107 @@ fn derived_keys_with_distinct_op_arcs_are_structurally_equal() { assert_eq!(hash_key(&lhs), hash_key(&rhs)); } -fn hash_key(key: &GlobalValKey) -> u64 { +fn hash_key(key: &ValueKey) -> u64 { let mut hasher = DefaultHasher::new(); key.hash(&mut hasher); hasher.finish() } -// === Fragment tests === +// === Graph tests === #[test] -fn fragment_builder_single_input() { - let mut builder = FragmentBuilder::::new(); +fn graph_builder_single_input() { + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); assert_eq!(x, 0); builder.set_outputs(vec![x]); - let frag = builder.build(); - assert_eq!(frag.inputs().len(), 1); - assert_eq!(frag.outputs().len(), 1); - assert_eq!(frag.vals()[x].key, GlobalValKey::Input("x".to_string())); - assert!(frag.vals()[x].producer.is_none()); + let graph = builder.build(); + assert_eq!(graph.inputs().len(), 1); + assert_eq!(graph.outputs().len(), 1); + assert_eq!(graph.values()[x].key, ValueKey::Input("x".to_string())); + assert!(graph.values()[x].producer.is_none()); } #[test] -fn fragment_builder_add_op() { - let mut builder = FragmentBuilder::::new(); +fn graph_builder_add_operation() { + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); let y = builder.add_input("y".to_string()); - let outputs = builder.add_op( + let outputs = builder.add_operation( ScalarOp::Add, - vec![ValRef::Local(x), ValRef::Local(y)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(y)], + OperationRole::Primary, ); assert_eq!(outputs.len(), 1); let sum_id = outputs[0]; builder.set_outputs(vec![sum_id]); - let frag = builder.build(); + let graph = builder.build(); - assert_eq!(frag.ops().len(), 1); - assert_eq!(frag.ops()[0].op, ScalarOp::Add); - assert!(frag.vals()[sum_id].producer.is_some()); + assert_eq!(graph.operations().len(), 1); + assert_eq!(graph.operations()[0].operation, ScalarOp::Add); + assert!(graph.values()[sum_id].producer.is_some()); - // Verify GlobalValKey structure - let expected_key = GlobalValKey::Derived { - op: Arc::new(GlobalOpKey::new( + // Verify ValueKey structure + let expected_key = ValueKey::Derived { + operation: Arc::new(OperationKey::new( ScalarOp::Add, vec![ - GlobalValKey::Input("x".to_string()), - GlobalValKey::Input("y".to_string()), + ValueKey::Input("x".to_string()), + ValueKey::Input("y".to_string()), ], - OpMode::Primal, + OperationRole::Primary, )), output_slot: 0, }; - assert_eq!(frag.vals()[sum_id].key, expected_key); + assert_eq!(graph.values()[sum_id].key, expected_key); } #[test] -fn fragment_builder_chain() { +fn graph_builder_chain() { // Build: Exp(Mul(x, a)) - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); let a = builder.add_input("a".to_string()); - let mul_out = builder.add_op( + let mul_out = builder.add_operation( ScalarOp::Mul, - vec![ValRef::Local(x), ValRef::Local(a)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(a)], + OperationRole::Primary, ); - let exp_out = builder.add_op( + let exp_out = builder.add_operation( ScalarOp::Exp, - vec![ValRef::Local(mul_out[0])], - OpMode::Primal, + vec![ValueRef::Local(mul_out[0])], + OperationRole::Primary, ); builder.set_outputs(vec![exp_out[0]]); - let frag = builder.build(); + let graph = builder.build(); - assert_eq!(frag.ops().len(), 2); - assert_eq!(frag.vals().len(), 4); // x, a, mul_out, exp_out + assert_eq!(graph.operations().len(), 2); + assert_eq!(graph.values().len(), 4); // x, a, mul_out, exp_out } #[test] -fn fragment_builder_dup_two_outputs() { - let mut builder = FragmentBuilder::::new(); +fn graph_builder_dup_two_outputs() { + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); - let dup_outs = builder.add_op(ScalarOp::Dup, vec![ValRef::Local(x)], OpMode::Primal); + let dup_outs = builder.add_operation( + ScalarOp::Dup, + vec![ValueRef::Local(x)], + OperationRole::Primary, + ); assert_eq!(dup_outs.len(), 2); builder.set_outputs(dup_outs.clone()); - let frag = builder.build(); + let graph = builder.build(); - assert_eq!(frag.outputs().len(), 2); + assert_eq!(graph.outputs().len(), 2); // Both outputs should be Derived with different output_slot - let key0 = &frag.vals()[dup_outs[0]].key; - let key1 = &frag.vals()[dup_outs[1]].key; + let key0 = &graph.values()[dup_outs[0]].key; + let key1 = &graph.values()[dup_outs[1]].key; match (key0, key1) { ( - GlobalValKey::Derived { + ValueKey::Derived { output_slot: s0, .. }, - GlobalValKey::Derived { + ValueKey::Derived { output_slot: s1, .. }, ) => { @@ -224,39 +234,39 @@ fn fragment_builder_dup_two_outputs() { // === Resolve tests === #[test] -fn resolve_single_fragment() { - let mut builder = FragmentBuilder::::new(); +fn resolve_single_graph() { + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); let y = builder.add_input("y".to_string()); - let sum = builder.add_op( + let sum = builder.add_operation( ScalarOp::Add, - vec![ValRef::Local(x), ValRef::Local(y)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(y)], + OperationRole::Primary, ); builder.set_outputs(vec![sum[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag.clone()]); + let view = resolve(vec![graph.clone()]); // Input keys should resolve - let x_key = GlobalValKey::Input("x".to_string()); - match view.resolve_val(&x_key).unwrap() { - ValDef::Input { key } => assert_eq!(key, "x"), + let x_key = ValueKey::Input("x".to_string()); + match view.resolve_value(&x_key).unwrap() { + ValueDef::Input { key } => assert_eq!(key, "x"), _ => panic!("expected Input"), } // Derived key should resolve - let sum_key = &frag.vals()[sum[0]].key; - match view.resolve_val(sum_key).unwrap() { - ValDef::Produced { - op, + let sum_key = &graph.values()[sum[0]].key; + match view.resolve_value(sum_key).unwrap() { + ValueDef::Produced { + operation, input_keys, - mode, + role, output_slot, } => { - assert_eq!(op, ScalarOp::Add); + assert_eq!(operation, ScalarOp::Add); assert_eq!(input_keys.len(), 2); - assert_eq!(mode, OpMode::Primal); + assert_eq!(role, OperationRole::Primary); assert_eq!(output_slot, 0); } _ => panic!("expected Produced"), @@ -264,27 +274,27 @@ fn resolve_single_fragment() { } #[test] -fn resolve_external_ref_across_fragments() { - // Fragment F0: x, a, mul = Mul(x, a) - let mut b0 = FragmentBuilder::::new(); +fn resolve_external_ref_across_graphs() { + // Graph F0: x, a, mul = Mul(x, a) + let mut b0 = GraphBuilder::::new(); let x = b0.add_input("x".to_string()); let a = b0.add_input("a".to_string()); - let mul = b0.add_op( + let mul = b0.add_operation( ScalarOp::Mul, - vec![ValRef::Local(x), ValRef::Local(a)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(a)], + OperationRole::Primary, ); let mul_key = b0.global_key(mul[0]).clone(); b0.set_outputs(vec![mul[0]]); let f0 = Arc::new(b0.build()); - // Fragment F1: references F0's mul output via External, applies Exp - let mut b1 = FragmentBuilder::::new(); + // Graph F1: references F0's mul output via External, applies Exp + let mut b1 = GraphBuilder::::new(); b1.add_parent(f0.clone()); - let exp = b1.add_op( + let exp = b1.add_operation( ScalarOp::Exp, - vec![ValRef::External(mul_key.clone())], - OpMode::Primal, + vec![ValueRef::External(mul_key.clone())], + OperationRole::Primary, ); b1.set_outputs(vec![exp[0]]); let f1 = Arc::new(b1.build()); @@ -292,13 +302,17 @@ fn resolve_external_ref_across_fragments() { let view = resolve(vec![f0, f1.clone()]); // mul_key should be resolvable - assert!(view.resolve_val(&mul_key).is_some()); + assert!(view.resolve_value(&mul_key).is_some()); // exp output should be resolvable - let exp_key = &f1.vals()[exp[0]].key; - match view.resolve_val(exp_key).unwrap() { - ValDef::Produced { op, input_keys, .. } => { - assert_eq!(op, ScalarOp::Exp); + let exp_key = &f1.values()[exp[0]].key; + match view.resolve_value(exp_key).unwrap() { + ValueDef::Produced { + operation, + input_keys, + .. + } => { + assert_eq!(operation, ScalarOp::Exp); assert_eq!(input_keys.len(), 1); assert_eq!(input_keys[0], mul_key); } @@ -308,104 +322,108 @@ fn resolve_external_ref_across_fragments() { #[test] fn resolve_unknown_key_returns_none() { - let builder = FragmentBuilder::::new(); - let frag = Arc::new(builder.build()); - let view = resolve(vec![frag]); - let unknown = GlobalValKey::Input("unknown".to_string()); - assert!(view.resolve_val(&unknown).is_none()); + let builder = GraphBuilder::::new(); + let graph = Arc::new(builder.build()); + let view = resolve(vec![graph]); + let unknown = ValueKey::Input("unknown".to_string()); + assert!(view.resolve_value(&unknown).is_none()); } // === Materialize tests === #[test] fn materialize_single_op() { - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); let y = builder.add_input("y".to_string()); - let sum = builder.add_op( + let sum = builder.add_operation( ScalarOp::Add, - vec![ValRef::Local(x), ValRef::Local(y)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(y)], + OperationRole::Primary, ); let sum_key = builder.global_key(sum[0]).clone(); builder.set_outputs(vec![sum[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[sum_key]); - assert_eq!(graph.ops.len(), 1); - assert_eq!(graph.ops[0].op, ScalarOp::Add); - assert_eq!(graph.vals.len(), 3); + assert_eq!(graph.operations.len(), 1); + assert_eq!(graph.operations[0].operation, ScalarOp::Add); + assert_eq!(graph.values.len(), 3); assert_eq!(graph.inputs.len(), 2); assert_eq!(graph.outputs.len(), 1); } #[test] fn materialize_chain() { - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); let a = builder.add_input("a".to_string()); - let mul = builder.add_op( + let mul = builder.add_operation( ScalarOp::Mul, - vec![ValRef::Local(x), ValRef::Local(a)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(a)], + OperationRole::Primary, + ); + let exp = builder.add_operation( + ScalarOp::Exp, + vec![ValueRef::Local(mul[0])], + OperationRole::Primary, ); - let exp = builder.add_op(ScalarOp::Exp, vec![ValRef::Local(mul[0])], OpMode::Primal); let exp_key = builder.global_key(exp[0]).clone(); builder.set_outputs(vec![exp[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[exp_key]); - assert_eq!(graph.ops.len(), 2); - assert_eq!(graph.vals.len(), 4); - assert_eq!(graph.ops[0].op, ScalarOp::Mul); - assert_eq!(graph.ops[1].op, ScalarOp::Exp); + assert_eq!(graph.operations.len(), 2); + assert_eq!(graph.values.len(), 4); + assert_eq!(graph.operations[0].operation, ScalarOp::Mul); + assert_eq!(graph.operations[1].operation, ScalarOp::Exp); } #[test] fn materialize_cse_deduplicates() { - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); - let sum = builder.add_op( + let sum = builder.add_operation( ScalarOp::Add, - vec![ValRef::Local(x), ValRef::Local(x)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(x)], + OperationRole::Primary, ); let sum_key = builder.global_key(sum[0]).clone(); builder.set_outputs(vec![sum[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[sum_key]); - assert_eq!(graph.vals.len(), 2); - assert_eq!(graph.ops.len(), 1); - assert_eq!(graph.ops[0].inputs[0], graph.ops[0].inputs[1]); + assert_eq!(graph.values.len(), 2); + assert_eq!(graph.operations.len(), 1); + assert_eq!(graph.operations[0].inputs[0], graph.operations[0].inputs[1]); } #[test] -fn materialize_across_fragments() { - let mut b0 = FragmentBuilder::::new(); +fn materialize_across_graphs() { + let mut b0 = GraphBuilder::::new(); let x = b0.add_input("x".to_string()); let a = b0.add_input("a".to_string()); - let mul = b0.add_op( + let mul = b0.add_operation( ScalarOp::Mul, - vec![ValRef::Local(x), ValRef::Local(a)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(a)], + OperationRole::Primary, ); let mul_key = b0.global_key(mul[0]).clone(); b0.set_outputs(vec![mul[0]]); let f0 = Arc::new(b0.build()); - let mut b1 = FragmentBuilder::::new(); + let mut b1 = GraphBuilder::::new(); b1.add_parent(f0.clone()); - let exp = b1.add_op( + let exp = b1.add_operation( ScalarOp::Exp, - vec![ValRef::External(mul_key)], - OpMode::Primal, + vec![ValueRef::External(mul_key)], + OperationRole::Primary, ); let exp_key = b1.global_key(exp[0]).clone(); b1.set_outputs(vec![exp[0]]); @@ -414,27 +432,27 @@ fn materialize_across_fragments() { let view = resolve(vec![f0, f1]); let graph = materialize_merge(&view, &[exp_key]); - assert_eq!(graph.ops.len(), 2); - assert_eq!(graph.vals.len(), 4); + assert_eq!(graph.operations.len(), 2); + assert_eq!(graph.values.len(), 4); } // === Compile + Eval tests === #[test] fn compile_and_eval_add() { - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); let y = builder.add_input("y".to_string()); - let sum = builder.add_op( + let sum = builder.add_operation( ScalarOp::Add, - vec![ValRef::Local(x), ValRef::Local(y)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(y)], + OperationRole::Primary, ); let sum_key = builder.global_key(sum[0]).clone(); builder.set_outputs(vec![sum[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[sum_key]); let prog = compile(&graph); @@ -448,20 +466,24 @@ fn compile_and_eval_add() { #[test] fn compile_and_eval_chain() { - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); let a = builder.add_input("a".to_string()); - let mul = builder.add_op( + let mul = builder.add_operation( ScalarOp::Mul, - vec![ValRef::Local(x), ValRef::Local(a)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(a)], + OperationRole::Primary, + ); + let exp = builder.add_operation( + ScalarOp::Exp, + vec![ValueRef::Local(mul[0])], + OperationRole::Primary, ); - let exp = builder.add_op(ScalarOp::Exp, vec![ValRef::Local(mul[0])], OpMode::Primal); let exp_key = builder.global_key(exp[0]).clone(); builder.set_outputs(vec![exp[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[exp_key]); let prog = compile(&graph); @@ -471,19 +493,19 @@ fn compile_and_eval_chain() { #[test] fn compile_and_eval_reuse() { - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); let y = builder.add_input("y".to_string()); - let sum = builder.add_op( + let sum = builder.add_operation( ScalarOp::Add, - vec![ValRef::Local(x), ValRef::Local(y)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(y)], + OperationRole::Primary, ); let sum_key = builder.global_key(sum[0]).clone(); builder.set_outputs(vec![sum[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[sum_key]); let prog = compile(&graph); @@ -493,21 +515,25 @@ fn compile_and_eval_reuse() { #[test] fn compile_and_eval_multi_output() { - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); let a = builder.add_input("a".to_string()); - let mul = builder.add_op( + let mul = builder.add_operation( ScalarOp::Mul, - vec![ValRef::Local(x), ValRef::Local(a)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(a)], + OperationRole::Primary, + ); + let exp = builder.add_operation( + ScalarOp::Exp, + vec![ValueRef::Local(mul[0])], + OperationRole::Primary, ); - let exp = builder.add_op(ScalarOp::Exp, vec![ValRef::Local(mul[0])], OpMode::Primal); let mul_key = builder.global_key(mul[0]).clone(); let exp_key = builder.global_key(exp[0]).clone(); builder.set_outputs(vec![mul[0], exp[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[mul_key, exp_key]); let prog = compile(&graph); @@ -520,21 +546,25 @@ fn compile_and_eval_multi_output() { // === End-to-end integration tests === #[test] -fn e2e_exp_ax_single_fragment() { - let mut builder = FragmentBuilder::::new(); +fn e2e_exp_ax_single_graph() { + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); let a = builder.add_input("a".to_string()); - let mul = builder.add_op( + let mul = builder.add_operation( ScalarOp::Mul, - vec![ValRef::Local(x), ValRef::Local(a)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(a)], + OperationRole::Primary, + ); + let exp = builder.add_operation( + ScalarOp::Exp, + vec![ValueRef::Local(mul[0])], + OperationRole::Primary, ); - let exp = builder.add_op(ScalarOp::Exp, vec![ValRef::Local(mul[0])], OpMode::Primal); let exp_key = builder.global_key(exp[0]).clone(); builder.set_outputs(vec![exp[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[exp_key]); let prog = compile(&graph); let result = prog.eval(&mut (), &[&2.0, &3.0]); @@ -543,25 +573,25 @@ fn e2e_exp_ax_single_fragment() { } #[test] -fn e2e_exp_ax_multi_fragment() { - let mut b0 = FragmentBuilder::::new(); +fn e2e_exp_ax_multi_graph() { + let mut b0 = GraphBuilder::::new(); let x = b0.add_input("x".to_string()); let a = b0.add_input("a".to_string()); - let mul = b0.add_op( + let mul = b0.add_operation( ScalarOp::Mul, - vec![ValRef::Local(x), ValRef::Local(a)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(a)], + OperationRole::Primary, ); let mul_key = b0.global_key(mul[0]).clone(); b0.set_outputs(vec![mul[0]]); let f0 = Arc::new(b0.build()); - let mut b1 = FragmentBuilder::::new(); + let mut b1 = GraphBuilder::::new(); b1.add_parent(f0.clone()); - let exp = b1.add_op( + let exp = b1.add_operation( ScalarOp::Exp, - vec![ValRef::External(mul_key)], - OpMode::Primal, + vec![ValueRef::External(mul_key)], + OperationRole::Primary, ); let exp_key = b1.global_key(exp[0]).clone(); b1.set_outputs(vec![exp[0]]); @@ -580,18 +610,18 @@ fn e2e_exp_ax_multi_fragment() { #[test] fn e2e_x_plus_x() { - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); - let sum = builder.add_op( + let sum = builder.add_operation( ScalarOp::Add, - vec![ValRef::Local(x), ValRef::Local(x)], - OpMode::Primal, + vec![ValueRef::Local(x), ValueRef::Local(x)], + OperationRole::Primary, ); let sum_key = builder.global_key(sum[0]).clone(); builder.set_outputs(vec![sum[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[sum_key]); let prog = compile(&graph); @@ -601,15 +631,23 @@ fn e2e_x_plus_x() { #[test] fn e2e_neg_exp() { - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); - let exp = builder.add_op(ScalarOp::Exp, vec![ValRef::Local(x)], OpMode::Primal); - let neg = builder.add_op(ScalarOp::Neg, vec![ValRef::Local(exp[0])], OpMode::Primal); + let exp = builder.add_operation( + ScalarOp::Exp, + vec![ValueRef::Local(x)], + OperationRole::Primary, + ); + let neg = builder.add_operation( + ScalarOp::Neg, + vec![ValueRef::Local(exp[0])], + OperationRole::Primary, + ); let neg_key = builder.global_key(neg[0]).clone(); builder.set_outputs(vec![neg[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[neg_key]); let prog = compile(&graph); @@ -619,19 +657,23 @@ fn e2e_neg_exp() { #[test] fn e2e_dup_and_add() { - let mut builder = FragmentBuilder::::new(); + let mut builder = GraphBuilder::::new(); let x = builder.add_input("x".to_string()); - let dup = builder.add_op(ScalarOp::Dup, vec![ValRef::Local(x)], OpMode::Primal); - let sum = builder.add_op( + let dup = builder.add_operation( + ScalarOp::Dup, + vec![ValueRef::Local(x)], + OperationRole::Primary, + ); + let sum = builder.add_operation( ScalarOp::Add, - vec![ValRef::Local(dup[0]), ValRef::Local(dup[1])], - OpMode::Primal, + vec![ValueRef::Local(dup[0]), ValueRef::Local(dup[1])], + OperationRole::Primary, ); let sum_key = builder.global_key(sum[0]).clone(); builder.set_outputs(vec![sum[0]]); - let frag = Arc::new(builder.build()); + let graph = Arc::new(builder.build()); - let view = resolve(vec![frag]); + let view = resolve(vec![graph]); let graph = materialize_merge(&view, &[sum_key]); let prog = compile(&graph);