Skip to main content

prjunnamed_generic/
decision.rs

1//! Decision tree lowering.
2//!
3//! The `decision` pass implements decision tree lowering based on the well-known heuristic
4//! algorithm developed for ML from the paper "Tree pattern matching for ML" by Marianne Baudinet
//! and David MacQueen (unpublished, 1985), the extended abstract of which is available from:
6//! *  <https://smlfamily.github.io/history/Baudinet-DM-tree-pat-match-12-85.pdf> (scan)
7//! *  <https://www.classes.cs.uchicago.edu/archive/2011/spring/22620-1/papers/macqueen-baudinet85.pdf> (OCR)
8//!
//! The algorithm is described in §4 "Decision trees and the dispatching problem". Only two
10//! of the heuristics described apply here: the relevance heuristic and the branching factor
11//! heuristic.
12
13use std::fmt::Display;
14use std::collections::{BTreeMap, BTreeSet};
15use std::rc::Rc;
16
17use prjunnamed_netlist::{AssignCell, Cell, CellRef, Const, Design, MatchCell, Net, Trit, Value};
18
19/// Maps `pattern` (a constant where 0 and 1 match the respective states, and X matches any state)
20/// to a set of `rules` (the nets that are asserted if the `pattern` matches the value being tested).
21#[derive(Debug, Clone, PartialEq, Eq)]
22struct MatchRow {
23    pattern: Const,
24    rules: BTreeSet<Net>,
25}
26
27/// Matches `value` against ordered `rows` of patterns and their corresponding rules, where the `rules`
28/// for the first pattern that matches the value being tested are asserted, and all other `rules`
29/// are deasserted.
30///
31/// Invariant: once a matrix is constructed, there is always a case that matches.
32/// Note that due to the possiblity of inputs being `X`, this is only
33/// satisfiable with a catch-all row.
34#[derive(Debug, Clone, PartialEq, Eq)]
35struct MatchMatrix {
36    value: Value,
37    rows: Vec<MatchRow>,
38}
39
40/// Describes the process of testing individual nets of a value (equivalently, eliminating columns
41/// from a match matrix) until a specific row is reached.
42#[derive(Debug, Clone, PartialEq, Eq)]
43enum Decision {
44    /// Drive a set of nets to 1.
45    Result { rules: BTreeSet<Net> },
46    /// Branch on the value of a particular net.
47    Branch { test: Net, if0: Box<Decision>, if1: Box<Decision> },
48}
49
50impl MatchRow {
51    fn new(pattern: impl Into<Const>, rules: impl IntoIterator<Item = Net>) -> Self {
52        Self { pattern: pattern.into(), rules: BTreeSet::from_iter(rules) }
53    }
54
55    fn empty(pattern_len: usize) -> Self {
56        Self::new(Const::undef(pattern_len), [])
57    }
58
59    fn merge(mut self, other: &MatchRow) -> Self {
60        self.pattern.extend(&other.pattern);
61        self.rules.extend(other.rules.iter().cloned());
62        self
63    }
64}
65
66impl MatchMatrix {
67    fn new(value: impl Into<Value>) -> Self {
68        Self { value: value.into(), rows: Vec::new() }
69    }
70
71    fn add(&mut self, row: MatchRow) -> usize {
72        assert_eq!(self.value.len(), row.pattern.len());
73        self.rows.push(row);
74        self.rows.len() - 1
75    }
76
77    fn add_enable(&mut self, enable: Net) {
78        if enable != Net::ONE {
79            for row in &mut self.rows {
80                row.pattern.push(Trit::One);
81            }
82            self.rows.insert(0, MatchRow::new(Const::undef(self.value.len()).concat(Trit::Zero), []));
83            self.value.push(enable);
84        }
85    }
86
87    /// Merge in a `MatchMatrix` describing a child `match` cell whose `enable`
88    /// input is being driven by `at`.
89    fn merge(mut self, at: Net, other: &MatchMatrix) -> Self {
90        self.value.extend(&other.value);
91        for self_row in std::mem::take(&mut self.rows) {
92            if self_row.rules.contains(&at) {
93                for other_row in &other.rows {
94                    self.add(self_row.clone().merge(other_row));
95                }
96            } else {
97                self.add(self_row.merge(&MatchRow::empty(other.value.len())));
98            }
99        }
100        self
101    }
102
103    fn iter_outputs(&self) -> impl Iterator<Item = Net> {
104        let mut outputs: Vec<Net> = self.rows.iter().flat_map(|row| row.rules.iter().copied()).collect();
105        outputs.sort();
106        outputs.dedup();
107        outputs.into_iter()
108    }
109
110    fn assume(mut self, target: Net, value: Trit) -> Self {
111        self.value =
112            Value::from_iter(self.value.into_iter().map(|net| if net == target { Net::from(value) } else { net }));
113        self
114    }
115
116    /// Remove redundant rows and columns.
117    ///
118    /// This function ensures the following normal-form properties:
119    /// - `self.value` does not contain any constant nets
120    /// - all nets occurs within `self.value` at most once
121    /// - a row of all `X` can only occur at the very end
122    /// - no two identical rows can occur immediately next to each other
123    ///
124    /// Note that the last of these properties is relatively weak, in that
125    /// stronger properties exist which can still be feasibly checked.
126    fn normalize(mut self) -> Self {
127        let mut remove_cols = BTreeSet::new();
128        let mut remove_rows = BTreeSet::new();
129
130        // Pick columns to remove where the matched value is constant or has repeated nets.
131
132        // For each `Net`, store the index of the first column being driven
133        // by that net.
134        let mut first_at = BTreeMap::new();
135        for (index, net) in self.value.iter().enumerate() {
136            if !net.is_const() && !first_at.contains_key(&net) {
137                first_at.insert(net, index);
138            } else {
139                remove_cols.insert(index);
140            }
141        }
142
143        // Pick rows to remove that:
144        // * are redundant with the immediately preceeding row, or
145        // * contradict themselves or the constant nets in matched value.
146        let mut prev_pattern = None;
147        'outer: for (row_index, row) in self.rows.iter_mut().enumerate() {
148            // Check if row will never match because of a previous row that:
149            // * has the same pattern, or
150            // * matches any value.
151            if let Some(ref prev_pattern) = prev_pattern
152                && (row.pattern == *prev_pattern || prev_pattern.is_undef())
153            {
154                remove_rows.insert(row_index);
155                continue;
156            }
157            prev_pattern = Some(row.pattern.clone());
158            // Check if row is contradictory.
159            for (net_index, net) in self.value.iter().enumerate() {
160                let mask = row.pattern[net_index];
161                // Row contradicts constant in matched value.
162                //
163                // Note that if we're matching against a constant `X`, removing
164                // the row is nevertheless valid, since the row isn't guaranteed
165                // to match regardless of the value of the `X`, and so we can
166                // refine it into "doesn't match".
167                match net.as_const() {
168                    Some(trit) if trit != mask && mask != Trit::Undef => {
169                        remove_rows.insert(row_index);
170                        continue 'outer;
171                    }
172                    _ => (),
173                }
174                // Check if the net appears multiple times in the matched value.
175                match first_at.get(&net) {
176                    // It doesn't.
177                    None => (),
178                    // It does and this is the first occurrence. Leave it alone.
179                    Some(&first_index) if first_index == net_index => (),
180                    // It does and this is the second or later occurrence. Check if it is compatible with
181                    // the first one. Also, if the first one was a don't-care, move this one into the position
182                    // of the first one.
183                    Some(&first_index) => {
184                        let first_mask = &mut row.pattern[first_index];
185                        if *first_mask != Trit::Undef && mask != Trit::Undef && *first_mask != mask {
186                            remove_rows.insert(row_index);
187                            continue 'outer;
188                        }
189                        if *first_mask == Trit::Undef {
190                            *first_mask = mask;
191                        }
192                    }
193                }
194            }
195        }
196
197        // Pick columns to remove where all of the patterns match any value.
198        let mut all_undef = vec![true; self.value.len()];
199        for (row_index, row) in self.rows.iter().enumerate() {
200            if remove_rows.contains(&row_index) {
201                continue;
202            }
203            for col_index in 0..self.value.len() {
204                if row.pattern[col_index] != Trit::Undef {
205                    all_undef[col_index] = false;
206                }
207            }
208        }
209        for (col_index, matches_any) in all_undef.into_iter().enumerate() {
210            if matches_any {
211                remove_cols.insert(col_index);
212            }
213        }
214
215        // Execute column and row removal.
216        fn remove_indices<'a, T>(
217            iter: impl IntoIterator<Item = T> + 'a,
218            remove_set: &'a BTreeSet<usize>,
219        ) -> impl Iterator<Item = T> + 'a {
220            iter.into_iter().enumerate().filter_map(|(index, elem)| (!remove_set.contains(&index)).then_some(elem))
221        }
222
223        self.value = Value::from_iter(remove_indices(self.value, &remove_cols));
224        self.rows = Vec::from_iter(remove_indices(self.rows, &remove_rows));
225        for row in &mut self.rows {
226            row.pattern = Const::from_iter(remove_indices(row.pattern.iter(), &remove_cols));
227        }
228        self
229    }
230
231    /// Construct a decision tree for the match matrix.
232    fn dispatch(mut self) -> Decision {
233        self = self.normalize();
234        if self.value.is_empty() || self.rows.len() == 1 {
235            Decision::Result { rules: self.rows.into_iter().next().map(|r| r.rules).unwrap_or_default() }
236        } else {
237            // Fanfiction of the heuristics from the 1986 paper that reduces them to: split the matrix on the column
238            // with the fewest don't-care's in it.
239            let mut undef_count = vec![0; self.value.len()];
240            for row in self.rows.iter() {
241                for (col_index, mask) in row.pattern.iter().enumerate() {
242                    if mask == Trit::Undef {
243                        undef_count[col_index] += 1;
244                    }
245                }
246            }
247            let test_index = (0..self.value.len()).min_by_key(|&i| undef_count[i]);
248            let test = self.value[test_index.unwrap()];
249
250            // Split the matrix into two, where the test net has a value of 0 and 1, and recurse.
251            let if0 = self.clone().assume(test, Trit::Zero).dispatch();
252            let if1 = self.assume(test, Trit::One).dispatch();
253            if if0 == if1 {
254                // Skip the branch if the outputs of the decision function are the same. This can happen
255                // e.g. in the following matrix:
256                //   00 => x
257                //   10 => x
258                //   XX => y
259                // regardless of the column selection order. This is readily apparent when the left-hand
260                // column is split on first, but even if the right-hand column is chosen, there will a
261                // case-split with both arms leading to `Decision::Result(x)`.
262                if0
263            } else {
264                Decision::Branch { test, if0: if0.into(), if1: if1.into() }
265            }
266        }
267    }
268}
269
270impl Decision {
271    /// Call `f` on the contents of each `Decision::Result` in this tree.
272    fn each_leaf(&self, f: &mut impl FnMut(&BTreeSet<Net>)) {
273        match self {
274            Decision::Result { rules } => f(rules),
275            Decision::Branch { if0, if1, .. } => {
276                if0.each_leaf(f);
277                if1.each_leaf(f);
278            }
279        }
280    }
281
282    /// Emit a mux-tree that outputs `values[x]` when net `x` would be `1`
283    /// according to the decision tree.
284    ///
285    /// Assumes that each case within `values` is mutually exclusive. Panics if
286    /// that is not the case.
287    fn emit_disjoint_mux(&self, design: &Design, values: &BTreeMap<Net, Value>, default: &Value) -> Value {
288        match self {
289            Decision::Result { rules } => {
290                let mut result = None;
291                for rule in rules {
292                    if let Some(value) = values.get(rule) {
293                        assert!(result.is_none());
294                        result = Some(value.clone());
295                    }
296                }
297                result.unwrap_or(default.clone())
298            }
299            Decision::Branch { test, if0, if1 } => design.add_mux(
300                *test,
301                if1.emit_disjoint_mux(design, values, default),
302                if0.emit_disjoint_mux(design, values, default),
303            ),
304        }
305    }
306
307    /// Emit a mux-tree that drives the `nets` according to the decision tree.
308    fn emit_one_hot_mux(&self, design: &Design, nets: &Value) -> Value {
309        match self {
310            Decision::Result { rules } => Value::from_iter(
311                nets.iter().map(|net| if rules.contains(&net) { Trit::One } else { Trit::Zero }.into()),
312            ),
313            Decision::Branch { test, if0, if1 } => {
314                design.add_mux(*test, if1.emit_one_hot_mux(design, nets), if0.emit_one_hot_mux(design, nets))
315            }
316        }
317    }
318}
319
320struct MatchTrees<'a> {
321    design: &'a Design,
322    /// Set of all `match` cells that aren't children. A `match` cell is a child
323    /// of another `match` cell if its `enable` input is being driven from
324    /// the output of the parent, and it is the unique such `match` cell for
325    /// that particular output bit. That is, if it is possible to merge the
326    /// child into the same decision tree.
327    roots: BTreeSet<CellRef<'a>>,
328    /// Maps a particular output of a `match` cell to the child `match` cell
329    /// whose `enable` input it is driving.
330    subtrees: BTreeMap<(CellRef<'a>, usize), CellRef<'a>>,
331}
332
333impl<'a> MatchTrees<'a> {
334    /// Recognize a tree of `match` cells, connected by their enable inputs.
335    fn build(design: &'a Design) -> MatchTrees<'a> {
336        let mut roots: BTreeSet<CellRef> = BTreeSet::new();
337        let mut subtrees: BTreeMap<(CellRef, usize), BTreeSet<CellRef>> = BTreeMap::new();
338        for cell_ref in design.iter_cells() {
339            let Cell::Match(MatchCell { enable, .. }) = &*cell_ref.get() else { continue };
340            let (enable_cell_ref, offset) = design.find_cell(*enable);
341            if let Cell::Match(_) = &*enable_cell_ref.get() {
342                // Driven by a match cell; may be a subtree or a root depending on its fanout.
343                subtrees.entry((enable_cell_ref, offset)).or_default().insert(cell_ref);
344                continue;
345            }
346            // Driven by some other cell or a constant; is a root.
347            roots.insert(cell_ref);
348        }
349
350        // Whenever multiple subtrees are connected to the same one-hot output, it is not possible
351        // to merge all of them into the same matrix. Turn all of these subtrees into roots.
352        let subtrees = subtrees
353            .into_iter()
354            .filter_map(|(key, subtrees)| {
355                if subtrees.len() == 1 {
356                    Some((key, subtrees.into_iter().next().unwrap()))
357                } else {
358                    roots.extend(subtrees);
359                    None
360                }
361            })
362            .collect();
363
364        Self { design, roots, subtrees }
365    }
366
367    /// Convert a tree of `match` cells into a matrix.
368    ///
369    /// Collects a list of all the cells being lifted into the matrix into
370    /// `all_cell_refs`.
371    ///
372    /// Replaces outputs that don't have any patterns at all with `Net::ZERO`,
373    /// but otherwise doesn't modify the design.
374    fn cell_into_matrix(&self, cell_ref: CellRef<'a>, all_cell_refs: &mut Vec<CellRef<'a>>) -> MatchMatrix {
375        let Cell::Match(match_cell) = &*cell_ref.get() else { unreachable!() };
376        let output = cell_ref.output();
377        all_cell_refs.push(cell_ref);
378
379        // Create matrix for this cell.
380        let mut matrix = MatchMatrix::new(&match_cell.value);
381        for (output_net, alternates) in output.iter().zip(match_cell.patterns.iter()) {
382            for pattern in alternates {
383                matrix.add(MatchRow::new(pattern.clone(), [output_net]));
384            }
385            if alternates.is_empty() {
386                self.design.replace_net(output_net, Net::ZERO);
387            }
388        }
389        matrix.add(MatchRow::empty(match_cell.value.len()));
390
391        // Create matrices for subtrees and merge them into the matrix for this cell.
392        for (offset, output_net) in output.iter().enumerate() {
393            if let Some(&sub_cell_ref) = self.subtrees.get(&(cell_ref, offset)) {
394                matrix = matrix.merge(output_net, &self.cell_into_matrix(sub_cell_ref, all_cell_refs));
395            }
396        }
397
398        matrix
399    }
400
401    /// For each tree of `match` cells, return a corresponding `MatchMatrix`
402    /// and a list of `match` cells that this matrix implements.
403    fn iter_matrices<'b>(&'b self) -> impl Iterator<Item = (MatchMatrix, Vec<CellRef<'b>>)> + 'b {
404        self.roots.iter().map(|&cell_ref| {
405            let Cell::Match(MatchCell { enable, .. }) = &*cell_ref.get() else { unreachable!() };
406            let mut all_cell_refs = Vec::new();
407            let mut matrix = self.cell_into_matrix(cell_ref, &mut all_cell_refs);
408            matrix.add_enable(*enable);
409            (matrix, all_cell_refs)
410        })
411    }
412}
413
414struct AssignChains<'a> {
415    chains: Vec<Vec<CellRef<'a>>>,
416}
417
418impl<'a> AssignChains<'a> {
419    fn build(design: &'a Design) -> AssignChains<'a> {
420        let mut roots: BTreeSet<CellRef> = BTreeSet::new();
421        let mut links: BTreeMap<CellRef, BTreeSet<CellRef>> = BTreeMap::new();
422        for cell_ref in design.iter_cells() {
423            let Cell::Assign(AssignCell { value, offset: 0, update, .. }) = &*cell_ref.get() else { continue };
424            if update.len() != value.len() {
425                continue;
426            }
427            let (value_cell_ref, _offset) = design.find_cell(value[0]);
428            if value_cell_ref.output() == *value
429                && let Cell::Assign(_) = &*value_cell_ref.get()
430            {
431                links.entry(value_cell_ref).or_default().insert(cell_ref);
432                continue;
433            }
434            roots.insert(cell_ref);
435        }
436
437        let mut chains = Vec::new();
438        for root in roots {
439            let mut chain = vec![root];
440            while let Some(links) = links.get(chain.last().unwrap()) {
441                if links.len() == 1 {
442                    chain.push(*links.first().unwrap());
443                } else {
444                    break;
445                }
446            }
447            if chain.len() > 1 {
448                chains.push(chain);
449            }
450        }
451
452        Self { chains }
453    }
454
455    fn iter_disjoint<'b>(
456        &'b self,
457        decisions: &'a BTreeMap<Net, Rc<Decision>>,
458        occurrences: &BTreeMap<Net, Vec<u32>>,
459    ) -> impl Iterator<Item = (Rc<Decision>, &'b [CellRef<'a>])> {
460        fn enable_of(cell_ref: CellRef) -> Net {
461            let Cell::Assign(AssignCell { enable, .. }) = &*cell_ref.get() else { unreachable!() };
462            *enable
463        }
464
465        self.chains.iter().filter_map(|chain| {
466            let mut used_branches = BTreeSet::new();
467            // Add all branches driving `net` to `used_branches`. Returns
468            // `false` if this is a conflict (i.e. the nets aren't mutually
469            // exclusive).
470            let mut consume_branches = |net: Net| -> bool {
471                let Some(occurs) = occurrences.get(&net) else {
472                    // net is not driven by any branches in this decision tree.
473                    // this can happen if a pattern turns out to be impossible
474                    // (e.g. due to constant propagation)
475                    return true;
476                };
477
478                for &occurrence in occurs {
479                    if !used_branches.insert(occurrence) {
480                        return false;
481                    }
482                }
483
484                true
485            };
486
487            // Check if the enables belong to disjoint branches within the same decision tree
488            // (like in a SystemVerilog "unique" or "unique0" statement).
489            let enable = enable_of(chain[0]);
490            let decision = decisions.get(&enable)?;
491            assert!(consume_branches(enable));
492            let mut end_index = chain.len();
493            'chain: for (index, &other_cell) in chain.iter().enumerate().skip(1) {
494                let enable = enable_of(other_cell);
495                let other_decision = decisions.get(&enable)?;
496                if !Rc::ptr_eq(decision, other_decision) || !consume_branches(enable) {
497                    end_index = index;
498                    break 'chain;
499                }
500            }
501            let chain = &chain[..end_index];
502
503            Some((decision.clone(), chain))
504        })
505    }
506}
507
/// Lowers all `match` and `assign` cells in `design` to `mux`-based logic.
///
/// Trees of `match` cells are merged into match matrices and turned into decision trees
/// that drive the one-hot `match` outputs. Chains of `assign` cells whose enables select
/// disjoint leaves of a single decision tree are lowered directly into a mux tree following
/// that decision tree. Every remaining `assign` cell is lowered to an individual `mux`.
/// Finishes with `design.compact()`.
pub fn decision(design: &mut Design) {
    // Detect and extract trees of `match` cells present in the netlist.
    let match_trees = MatchTrees::build(design);

    // Detect and extract chains of `assign` statements without slicing.
    let assign_chains = AssignChains::build(design);

    // Combine each tree of `match` cells into a single match matrix.
    // Then build a decision tree for it and use it to drive the output.
    // `decisions` maps every output net of a matrix to the decision tree that now drives it.
    let mut decisions: BTreeMap<Net, Rc<Decision>> = BTreeMap::new();

    // Every leaf of every decision tree gets a unique index; `occurrences` maps each output
    // net to the indices of the leaves in which it is asserted. `AssignChains::iter_disjoint`
    // uses this to check that enables are mutually exclusive.
    let mut next_branch: u32 = 0;
    let mut occurrences: BTreeMap<Net, Vec<u32>> = BTreeMap::new();

    for (matrix, matches) in match_trees.iter_matrices() {
        // Collected before `dispatch` consumes the matrix.
        let all_outputs = BTreeSet::from_iter(matrix.iter_outputs());
        if cfg!(feature = "trace") {
            eprint!(">matrix:\n{matrix}");
        }

        let decision = Rc::new(matrix.dispatch());
        if cfg!(feature = "trace") {
            eprint!(">decision tree:\n{decision}")
        }

        decision.each_leaf(&mut |outputs| {
            let branch = next_branch;
            next_branch += 1;

            for &output in outputs {
                occurrences.entry(output).or_default().push(branch);
            }
        });

        for &output in &all_outputs {
            decisions.insert(output, decision.clone());
        }

        // New cells inherit metadata from all the `match` cells implemented by this matrix.
        let _guard = design.use_metadata_from(&matches[..]);
        let nets = Value::from_iter(all_outputs);
        design.replace_value(&nets, decision.emit_one_hot_mux(design, &nets));
    }

    // Find chains of `assign` cells that are order-independent.
    // Then lower these cells to a `mux` tree without `eq` cells.
    let mut used_assigns = BTreeSet::new();
    for (decision, chain) in assign_chains.iter_disjoint(&decisions, &occurrences) {
        let (first_assign, last_assign) = (chain.first().unwrap(), chain.last().unwrap());
        if cfg!(feature = "trace") {
            eprintln!(">disjoint:");
            for &cell_ref in chain {
                eprintln!("{}", design.display_cell(cell_ref));
            }
        }

        // The chain's initial `value` is the fallback when no enable in the chain is asserted.
        let mut values = BTreeMap::new();
        let Cell::Assign(AssignCell { value: default, .. }) = &*first_assign.get() else { unreachable!() };
        for &cell_ref in chain {
            let Cell::Assign(AssignCell { enable, update, .. }) = &*cell_ref.get() else { unreachable!() };
            values.insert(*enable, update.clone());
        }

        let _guard = design.use_metadata_from(chain);
        design.replace_value(last_assign.output(), decision.emit_disjoint_mux(design, &values, default));
        // Only the last cell is marked as used; the earlier cells of the chain are still lowered
        // by the loop below. NOTE(review): presumably their outputs become unused and are removed
        // by `compact()` — confirm.
        used_assigns.insert(*last_assign);
    }

    // Lower other `assign` cells.
    for cell_ref in design.iter_cells().filter(|cell_ref| !used_assigns.contains(cell_ref)) {
        let Cell::Assign(assign_cell) = &*cell_ref.get() else { continue };
        if cfg!(feature = "trace") {
            eprintln!(">chained: {}", design.display_cell(cell_ref));
        }

        // Output is `value`, except the `update.len()` nets starting at `offset`, which are
        // `enable ? update : value[offset..]`.
        let _guard = design.use_metadata_from(&[cell_ref]);
        let mut nets = Vec::from_iter(assign_cell.value.iter());
        let slice = assign_cell.offset..(assign_cell.offset + assign_cell.update.len());
        nets[slice.clone()].copy_from_slice(
            &design.add_mux(assign_cell.enable, &assign_cell.update, assign_cell.value.slice(slice))[..],
        );
        design.replace_value(cell_ref.output(), Value::from(nets));
    }

    design.compact();
}
593
594impl Display for MatchRow {
595    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
596        for (index, trit) in self.pattern.iter().rev().enumerate() {
597            if index != 0 && index.is_multiple_of(4) {
598                write!(f, "_")?;
599            }
600            write!(f, "{trit}")?;
601        }
602        write!(f, " =>")?;
603        if self.rules.is_empty() {
604            return write!(f, " (empty)");
605        }
606        for rule in &self.rules {
607            write!(f, " {rule}")?;
608        }
609        Ok(())
610    }
611}
612
613impl Display for MatchMatrix {
614    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
615        writeln!(f, "{}:", self.value)?;
616        for row in &self.rows {
617            writeln!(f, "  {row}")?;
618        }
619        Ok(())
620    }
621}
622
623impl Decision {
624    fn format(&self, f: &mut std::fmt::Formatter, level: usize) -> std::fmt::Result {
625        let format_rules = |f: &mut std::fmt::Formatter, rules: &BTreeSet<Net>| {
626            if rules.is_empty() {
627                write!(f, " (empty)")
628            } else {
629                for rule in rules {
630                    write!(f, " {rule}")?;
631                }
632                Ok(())
633            }
634        };
635
636        let format_decision = |f: &mut std::fmt::Formatter, net: Net, value: usize, decision: &Decision| {
637            if let Decision::Result { rules } = decision
638                && rules.is_empty()
639            {
640                return Ok(());
641            }
642            for _ in 0..level {
643                write!(f, "  ")?;
644            }
645            match decision {
646                Decision::Result { rules } => {
647                    write!(f, "{net} = {value} =>")?;
648                    format_rules(f, rules)?;
649                    writeln!(f)
650                }
651                Decision::Branch { .. } => {
652                    writeln!(f, "{net} = {value} =>")?;
653                    decision.format(f, level + 1)
654                }
655            }
656        };
657
658        match self {
659            Decision::Result { rules } => {
660                assert_eq!(level, 0);
661                write!(f, "=>")?;
662                format_rules(f, rules)?;
663                writeln!(f)?;
664            }
665            Decision::Branch { test, if0, if1 } => {
666                format_decision(f, *test, 0, if0)?;
667                format_decision(f, *test, 1, if1)?;
668            }
669        }
670        Ok(())
671    }
672}
673
674impl Display for Decision {
675    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
676        self.format(f, 0)
677    }
678}
679
680#[cfg(test)]
681mod test {
682    #![allow(non_snake_case)]
683
684    use std::collections::{BTreeMap, BTreeSet};
685
686    use prjunnamed_netlist::{assert_isomorphic, AssignCell, Cell, Const, Design, MatchCell, Net, Value};
687
688    use super::{decision, AssignChains, Decision, MatchMatrix, MatchRow, MatchTrees};
689
690    struct Helper(Design);
691
692    impl Helper {
693        fn new() -> Self {
694            Self(Design::new())
695        }
696
697        fn val(&self, width: usize) -> Value {
698            self.0.add_void(width)
699        }
700
701        fn net(&self) -> Net {
702            self.0.add_void(1).unwrap_net()
703        }
704
705        fn rs(&self, rule: Net) -> Box<Decision> {
706            Box::new(Decision::Result { rules: BTreeSet::from_iter([rule]) })
707        }
708
709        fn br(&self, test: Net, if1: Box<Decision>, if0: Box<Decision>) -> Box<Decision> {
710            Box::new(Decision::Branch { test, if0, if1 })
711        }
712    }
713
714    #[test]
715    fn test_add_enable() {
716        let h = Helper::new();
717
718        let v = h.val(2);
719        let (n1, n2, en) = (h.net(), h.net(), h.net());
720
721        let mut ml = MatchMatrix::new(&v);
722        ml.add(MatchRow::new(Const::lit("10"), [n1]));
723        ml.add(MatchRow::new(Const::lit("01"), [n2]));
724
725        let mut mr = MatchMatrix::new(v.concat(en));
726        mr.add(MatchRow::new(Const::lit("0XX"), []));
727        mr.add(MatchRow::new(Const::lit("110"), [n1]));
728        mr.add(MatchRow::new(Const::lit("101"), [n2]));
729
730        ml.add_enable(en);
731        assert_eq!(ml, mr, "\n{ml} != \n{mr}");
732    }
733
734    #[test]
735    fn test_add_enable_trivial() {
736        let h = Helper::new();
737
738        let v = h.val(2);
739        let (n1, n2) = (h.net(), h.net());
740
741        let mut ml = MatchMatrix::new(&v);
742        ml.add(MatchRow::new(Const::lit("10"), [n1]));
743        ml.add(MatchRow::new(Const::lit("01"), [n2]));
744
745        let mr = ml.clone();
746
747        ml.add_enable(Net::ONE);
748        assert_eq!(ml, mr, "\n{ml} != \n{mr}");
749    }
750
751    #[test]
752    fn test_merge_1() {
753        let h = Helper::new();
754
755        let v1 = h.val(2);
756        let (n11, n12) = (h.net(), h.net());
757        let v2 = h.val(3);
758        let (n21, n22) = (h.net(), h.net());
759        let mut m1 = MatchMatrix::new(&v1);
760        m1.add(MatchRow::new(Const::lit("10"), [n11]));
761        m1.add(MatchRow::new(Const::lit("01"), [n12]));
762        m1.add(MatchRow::new(Const::lit("XX"), []));
763
764        let mut m2 = MatchMatrix::new(&v2);
765        m2.add(MatchRow::new(Const::lit("X00"), [n21]));
766        m2.add(MatchRow::new(Const::lit("10X"), [n22]));
767        m2.add(MatchRow::new(Const::lit("XXX"), []));
768
769        let ml1 = m1.clone().merge(n12, &m2);
770
771        let mut mr1 = MatchMatrix::new(v1.concat(&v2));
772        mr1.add(MatchRow::new(Const::lit("XXX10"), [n11]));
773        mr1.add(MatchRow::new(Const::lit("X0001"), [n12, n21]));
774        mr1.add(MatchRow::new(Const::lit("10X01"), [n12, n22]));
775        mr1.add(MatchRow::new(Const::lit("XXX01"), [n12]));
776        mr1.add(MatchRow::new(Const::lit("XXXXX"), []));
777
778        assert_eq!(ml1, mr1, "\n{ml1} != \n{mr1}");
779    }
780
781    #[test]
782    fn test_merge_2() {
783        let h = Helper::new();
784
785        let v1 = h.val(2);
786        let (n11, n12) = (h.net(), h.net());
787        let v2 = h.val(3);
788        let (n21, n22) = (h.net(), h.net());
789        let mut m1 = MatchMatrix::new(&v1);
790        m1.add(MatchRow::new(Const::lit("10"), [n11]));
791        m1.add(MatchRow::new(Const::lit("01"), [n11]));
792        m1.add(MatchRow::new(Const::lit("XX"), [n12]));
793
794        let mut m2 = MatchMatrix::new(&v2);
795        m2.add(MatchRow::new(Const::lit("X00"), [n21]));
796        m2.add(MatchRow::new(Const::lit("10X"), [n22]));
797        m2.add(MatchRow::new(Const::lit("XXX"), []));
798
799        let ml1 = m1.clone().merge(n11, &m2);
800
801        let mut mr1 = MatchMatrix::new(v1.concat(&v2));
802        mr1.add(MatchRow::new(Const::lit("X0010"), [n11, n21]));
803        mr1.add(MatchRow::new(Const::lit("10X10"), [n11, n22]));
804        mr1.add(MatchRow::new(Const::lit("XXX10"), [n11]));
805        mr1.add(MatchRow::new(Const::lit("X0001"), [n11, n21]));
806        mr1.add(MatchRow::new(Const::lit("10X01"), [n11, n22]));
807        mr1.add(MatchRow::new(Const::lit("XXX01"), [n11]));
808        mr1.add(MatchRow::new(Const::lit("XXXXX"), [n12]));
809
810        assert_eq!(ml1, mr1, "\n{ml1} != \n{mr1}");
811    }
812
813    #[test]
814    fn test_normalize_vertical() {
815        let h = Helper::new();
816        let n = h.net();
817
818        let mut m00 = MatchMatrix::new(Value::from(Const::lit("0")));
819        m00.add(MatchRow::new(Const::lit("0"), [n]));
820
821        let mut m01 = MatchMatrix::new(Value::from(Const::lit("0")));
822        m01.add(MatchRow::new(Const::lit("1"), [n]));
823
824        let mut m0X = MatchMatrix::new(Value::from(Const::lit("0")));
825        m0X.add(MatchRow::new(Const::lit("X"), [n]));
826
827        let mut m10 = MatchMatrix::new(Value::from(Const::lit("1")));
828        m10.add(MatchRow::new(Const::lit("0"), [n]));
829
830        let mut m11 = MatchMatrix::new(Value::from(Const::lit("1")));
831        m11.add(MatchRow::new(Const::lit("1"), [n]));
832
833        let mut m1X = MatchMatrix::new(Value::from(Const::lit("1")));
834        m1X.add(MatchRow::new(Const::lit("X"), [n]));
835
836        let mut mX0 = MatchMatrix::new(Value::from(Const::lit("X")));
837        mX0.add(MatchRow::new(Const::lit("0"), [n]));
838
839        let mut mX1 = MatchMatrix::new(Value::from(Const::lit("X")));
840        mX1.add(MatchRow::new(Const::lit("1"), [n]));
841
842        let mut mXX = MatchMatrix::new(Value::from(Const::lit("X")));
843        mXX.add(MatchRow::new(Const::lit("X"), [n]));
844
845        for must_reject in [m01, m10, mX0, mX1] {
846            let normalized = must_reject.clone().normalize();
847            assert_eq!(normalized.rows.len(), 0, "before:\n{must_reject}\nafter:\n{normalized}");
848        }
849        for must_accept in [m00, m0X, m11, m1X, mXX] {
850            let normalized = must_accept.clone().normalize();
851            assert_eq!(normalized.rows.len(), 1, "before:\n{must_accept}\nafter:\n{normalized}");
852            assert_eq!(normalized.rows[0].pattern.len(), 0, "before:\n{must_accept}\nafter:\n{normalized}");
853        }
854    }
855
856    #[test]
857    fn test_normalize_horizontal() {
858        let h = Helper::new();
859        let v = h.val(1);
860        let n = h.net();
861
862        let mut m1 = MatchMatrix::new(v.concat(&v));
863        m1.add(MatchRow::new(Const::lit("0X"), [n]));
864        m1 = m1.normalize();
865        assert_eq!(m1.rows[0].pattern, Const::lit("0"));
866
867        let mut m2 = MatchMatrix::new(v.concat(&v));
868        m2.add(MatchRow::new(Const::lit("X0"), [n]));
869        m2 = m2.normalize();
870        assert_eq!(m2.rows[0].pattern, Const::lit("0"));
871
872        let mut m3 = MatchMatrix::new(v.concat(&v));
873        m3.add(MatchRow::new(Const::lit("10"), [n]));
874        m3 = m3.normalize();
875        assert_eq!(m3.rows.len(), 0);
876    }
877
878    #[test]
879    fn test_normalize_duplicate_row() {
880        let h = Helper::new();
881        let v = h.val(2);
882        let (n1, n2) = (h.net(), h.net());
883
884        let mut m = MatchMatrix::new(v);
885        m.add(MatchRow::new(Const::lit("10"), [n1]));
886        m.add(MatchRow::new(Const::lit("10"), [n2]));
887        m = m.normalize();
888        assert_eq!(m.rows.len(), 1);
889        assert_eq!(m.rows[0].pattern, Const::lit("10"));
890        assert_eq!(m.rows[0].rules, BTreeSet::from_iter([n1]));
891    }
892
893    #[test]
894    fn test_normalize_irrefitable() {
895        let h = Helper::new();
896        let v = h.val(2);
897        let (n1, n2) = (h.net(), h.net());
898
899        let mut m = MatchMatrix::new(v);
900        m.add(MatchRow::new(Const::lit("XX"), [n1]));
901        m.add(MatchRow::new(Const::lit("10"), [n2]));
902        m = m.normalize();
903        assert_eq!(m.rows.len(), 1);
904        assert_eq!(m.rows[0].pattern, Const::lit(""));
905        assert_eq!(m.rows[0].rules, BTreeSet::from_iter([n1]));
906    }
907
908    #[test]
909    fn test_normalize_unused_column() {
910        let h = Helper::new();
911        let v = h.val(2);
912        let (n1, n2) = (h.net(), h.net());
913
914        let mut m = MatchMatrix::new(&v);
915        m.add(MatchRow::new(Const::lit("X0"), [n1]));
916        m.add(MatchRow::new(Const::lit("X1"), [n2]));
917        m = m.normalize();
918        assert_eq!(m.value, v.slice(0..1));
919        assert_eq!(m.rows.len(), 2);
920        assert_eq!(m.rows[0], MatchRow::new(Const::lit("0"), [n1]));
921        assert_eq!(m.rows[1], MatchRow::new(Const::lit("1"), [n2]));
922    }
923
924    #[test]
925    fn test_normalize_unused_column_after_elim() {
926        let h = Helper::new();
927        let v = h.val(2);
928        let (n1, n2, n3) = (h.net(), h.net(), h.net());
929
930        let mut m = MatchMatrix::new(v.concat(&v));
931        m.add(MatchRow::new(Const::lit("XXX0"), [n1]));
932        m.add(MatchRow::new(Const::lit("XXX1"), [n2]));
933        m.add(MatchRow::new(Const::lit("1X0X"), [n3]));
934        m = m.normalize();
935        assert_eq!(m.value, v.slice(0..1));
936        assert_eq!(m.rows.len(), 2);
937        assert_eq!(m.rows[0], MatchRow::new(Const::lit("0"), [n1]));
938        assert_eq!(m.rows[1], MatchRow::new(Const::lit("1"), [n2]));
939    }
940
941    macro_rules! assert_dispatch {
942        ($m:expr, $d:expr) => {
943            let dl = $m.clone().dispatch();
944            let dr = $d;
945            assert!(dl == *dr, "\ndispatching {}\n{} != \n{}", $m, dl, dr);
946        };
947    }
948
949    #[test]
950    fn test_decide_0() {
951        let h = Helper::new();
952
953        let v = h.val(1);
954        let n = h.net();
955        let mut m = MatchMatrix::new(&v);
956        m.add(MatchRow::new(Const::lit("0"), [n]));
957
958        let d = h.rs(n);
959
960        assert_dispatch!(m, d);
961    }
962
963    #[test]
964    fn test_decide_0_1() {
965        let h = Helper::new();
966
967        let v = h.val(1);
968        let (n1, n2) = (h.net(), h.net());
969        let mut m = MatchMatrix::new(&v);
970        m.add(MatchRow::new(Const::lit("0"), [n1]));
971        m.add(MatchRow::new(Const::lit("1"), [n2]));
972
973        let d = h.br(v[0], h.rs(n2), h.rs(n1));
974
975        assert_dispatch!(m, d);
976    }
977
978    #[test]
979    fn test_decide_0_X() {
980        let h = Helper::new();
981
982        let v = h.val(1);
983        let (n1, n2) = (h.net(), h.net());
984        let mut m = MatchMatrix::new(&v);
985        m.add(MatchRow::new(Const::lit("0"), [n1]));
986        m.add(MatchRow::new(Const::lit("X"), [n2]));
987
988        let d = h.br(v[0], h.rs(n2), h.rs(n1));
989
990        assert_dispatch!(m, d);
991    }
992
993    #[test]
994    fn test_decide_1_X() {
995        let h = Helper::new();
996
997        let v = h.val(1);
998        let (n1, n2) = (h.net(), h.net());
999        let mut m = MatchMatrix::new(&v);
1000        m.add(MatchRow::new(Const::lit("1"), [n1]));
1001        m.add(MatchRow::new(Const::lit("X"), [n2]));
1002
1003        let d = h.br(v[0], h.rs(n1), h.rs(n2));
1004
1005        assert_dispatch!(m, d);
1006    }
1007
1008    #[test]
1009    fn test_decide_X_0_1() {
1010        let h = Helper::new();
1011
1012        let v = h.val(1);
1013        let (n1, n2, n3) = (h.net(), h.net(), h.net());
1014        let mut m = MatchMatrix::new(&v);
1015        m.add(MatchRow::new(Const::lit("X"), [n1]));
1016        m.add(MatchRow::new(Const::lit("0"), [n2]));
1017        m.add(MatchRow::new(Const::lit("1"), [n3]));
1018
1019        let d = h.rs(n1);
1020
1021        assert_dispatch!(m, d);
1022    }
1023
1024    #[test]
1025    fn test_decide_0_1_X() {
1026        let h = Helper::new();
1027
1028        let v = h.val(1);
1029        let (n1, n2, n3) = (h.net(), h.net(), h.net());
1030        let mut m = MatchMatrix::new(&v);
1031        m.add(MatchRow::new(Const::lit("0"), [n1]));
1032        m.add(MatchRow::new(Const::lit("1"), [n2]));
1033        m.add(MatchRow::new(Const::lit("X"), [n3]));
1034
1035        let d = h.br(v[0], h.rs(n2), h.rs(n1));
1036
1037        assert_dispatch!(m, d);
1038    }
1039
1040    #[test]
1041    fn test_decide_0X_1X_XX() {
1042        let h = Helper::new();
1043
1044        let v = h.val(2);
1045        let (n1, n2, n3) = (h.net(), h.net(), h.net());
1046        let mut m = MatchMatrix::new(&v);
1047        m.add(MatchRow::new(Const::lit("X0"), [n1]));
1048        m.add(MatchRow::new(Const::lit("X1"), [n2]));
1049        m.add(MatchRow::new(Const::lit("XX"), [n3]));
1050
1051        let d = h.br(v[0], h.rs(n2), h.rs(n1));
1052
1053        assert_dispatch!(m, d);
1054    }
1055
1056    #[test]
1057    fn test_decide_0X_11_XX() {
1058        let h = Helper::new();
1059
1060        let v = h.val(2);
1061        let (n1, n2, n3) = (h.net(), h.net(), h.net());
1062        let mut m = MatchMatrix::new(&v);
1063        m.add(MatchRow::new(Const::lit("0X"), [n1]));
1064        m.add(MatchRow::new(Const::lit("11"), [n2]));
1065        m.add(MatchRow::new(Const::lit("XX"), [n3]));
1066
1067        let d = h.br(v[1], h.br(v[0], h.rs(n2), h.rs(n3)), h.rs(n1));
1068
1069        assert_dispatch!(m, d);
1070    }
1071
1072    #[test]
1073    fn test_decide_00_10_XX() {
1074        let h = Helper::new();
1075
1076        let v = h.val(2);
1077        let (n1, n2) = (h.net(), h.net());
1078        let mut m = MatchMatrix::new(&v);
1079        m.add(MatchRow::new(Const::lit("00"), [n1]));
1080        m.add(MatchRow::new(Const::lit("01"), [n1]));
1081        m.add(MatchRow::new(Const::lit("XX"), [n2]));
1082
1083        let d = h.br(v[1], h.rs(n2), h.rs(n1));
1084
1085        assert_dispatch!(m, d);
1086    }
1087
1088    #[test]
1089    fn test_match_tree_build_root_1() {
1090        let mut design = Design::new();
1091        let a = design.add_input("a", 1);
1092        let y = design.add_match(MatchCell { value: a, enable: Net::ONE, patterns: vec![vec![Const::lit("0")]] });
1093        design.apply();
1094
1095        let y_cell = design.find_cell(y[0]).0;
1096
1097        let match_trees = MatchTrees::build(&design);
1098        assert!(match_trees.roots == BTreeSet::from_iter([y_cell]));
1099        assert!(match_trees.subtrees == BTreeMap::from_iter([]));
1100    }
1101
1102    #[test]
1103    fn test_match_tree_build_root_pi() {
1104        let mut design = Design::new();
1105        let a = design.add_input("a", 1);
1106        let b = design.add_input1("b");
1107        let y = design.add_match(MatchCell { value: a, enable: b, patterns: vec![vec![Const::lit("0")]] });
1108        design.apply();
1109
1110        let y_cell = design.find_cell(y[0]).0;
1111
1112        let match_trees = MatchTrees::build(&design);
1113        assert!(match_trees.roots == BTreeSet::from_iter([y_cell]));
1114        assert!(match_trees.subtrees == BTreeMap::from_iter([]));
1115    }
1116
1117    #[test]
1118    fn test_match_tree_build_root_subtree() {
1119        let mut design = Design::new();
1120        let a = design.add_input("a", 1);
1121        let y1 =
1122            design.add_match(MatchCell { value: a.clone(), enable: Net::ONE, patterns: vec![vec![Const::lit("0")]] });
1123        let y2 = design.add_match(MatchCell { value: a, enable: y1[0], patterns: vec![vec![Const::lit("0")]] });
1124        design.apply();
1125
1126        let y1_cell = design.find_cell(y1[0]).0;
1127        let y2_cell = design.find_cell(y2[0]).0;
1128
1129        let match_trees = MatchTrees::build(&design);
1130        assert!(match_trees.roots == BTreeSet::from_iter([y1_cell]));
1131        assert!(match_trees.subtrees == BTreeMap::from_iter([((y1_cell, 0), y2_cell)]));
1132    }
1133
1134    #[test]
1135    fn test_match_tree_build_root_subtrees_disjoint() {
1136        let mut design = Design::new();
1137        let a = design.add_input("a", 1);
1138        let y1 = design.add_match(MatchCell {
1139            value: a.clone(),
1140            enable: Net::ONE,
1141            patterns: vec![vec![Const::lit("0")], vec![Const::lit("1")]],
1142        });
1143        let y2 = design.add_match(MatchCell { value: a.clone(), enable: y1[0], patterns: vec![vec![Const::lit("0")]] });
1144        let y3 = design.add_match(MatchCell { value: a, enable: y1[1], patterns: vec![vec![Const::lit("0")]] });
1145        design.apply();
1146
1147        let y1_cell = design.find_cell(y1[0]).0;
1148        let y2_cell = design.find_cell(y2[0]).0;
1149        let y3_cell = design.find_cell(y3[0]).0;
1150
1151        let match_trees = MatchTrees::build(&design);
1152        assert!(match_trees.roots == BTreeSet::from_iter([y1_cell]));
1153        assert!(match_trees.subtrees == BTreeMap::from_iter([((y1_cell, 0), y2_cell), ((y1_cell, 1), y3_cell)]));
1154    }
1155
1156    #[test]
1157    fn test_match_tree_build_root_subtrees_rerooted() {
1158        let mut design = Design::new();
1159        let a = design.add_input("a", 1);
1160        let y1 =
1161            design.add_match(MatchCell { value: a.clone(), enable: Net::ONE, patterns: vec![vec![Const::lit("0")]] });
1162        let y2 = design.add_match(MatchCell { value: a.clone(), enable: y1[0], patterns: vec![vec![Const::lit("0")]] });
1163        let y3 = design.add_match(MatchCell { value: a, enable: y1[0], patterns: vec![vec![Const::lit("1")]] });
1164        design.apply();
1165
1166        let y1_cell = design.find_cell(y1[0]).0;
1167        let y2_cell = design.find_cell(y2[0]).0;
1168        let y3_cell = design.find_cell(y3[0]).0;
1169
1170        let match_trees = MatchTrees::build(&design);
1171        assert!(match_trees.roots == BTreeSet::from_iter([y1_cell, y2_cell, y3_cell]));
1172        assert!(match_trees.subtrees == BTreeMap::from_iter([]));
1173    }
1174
1175    #[test]
1176    fn test_match_cell_into_matrix_flat() {
1177        let mut design = Design::new();
1178        let a = design.add_input("a", 3);
1179        let y = design.add_match(MatchCell {
1180            value: a.clone(),
1181            enable: Net::ONE,
1182            patterns: vec![vec![Const::lit("000"), Const::lit("111")], vec![Const::lit("010")]],
1183        });
1184        let yy = design.add_buf(&y);
1185        design.add_output("y", &yy);
1186        design.apply();
1187
1188        let y_cell = design.find_cell(y[0]).0;
1189        let mut match_cells = Vec::new();
1190        let m = MatchTrees::build(&design).cell_into_matrix(y_cell, &mut match_cells);
1191        assert_eq!(match_cells.len(), 1);
1192        assert_eq!(match_cells[0].output(), y);
1193        design.apply();
1194
1195        let yy_cell = design.find_cell(yy[0]).0;
1196        let Cell::Buf(y) = &*yy_cell.get() else { unreachable!() };
1197        assert_eq!(m.value, a);
1198        assert_eq!(m.rows, vec![
1199            MatchRow::new(Const::lit("000"), [y[0]]),
1200            MatchRow::new(Const::lit("111"), [y[0]]),
1201            MatchRow::new(Const::lit("010"), [y[1]]),
1202            MatchRow::new(Const::lit("XXX"), []),
1203        ]);
1204    }
1205
1206    #[test]
1207    fn test_match_cell_into_matrix_nested() {
1208        let mut design = Design::new();
1209        let a = design.add_input("a", 3);
1210        let b = design.add_input("b", 2);
1211        let ya = design.add_match(MatchCell {
1212            value: a.clone(),
1213            enable: Net::ONE,
1214            patterns: vec![vec![Const::lit("000"), Const::lit("111")], vec![Const::lit("010")]],
1215        });
1216        let yb = design.add_match(MatchCell {
1217            value: b.clone(),
1218            enable: ya[1],
1219            patterns: vec![vec![Const::lit("00")], vec![Const::lit("11")]],
1220        });
1221        let yya = design.add_buf(&ya);
1222        let yyb = design.add_buf(&yb);
1223        design.add_output("ya", &yya);
1224        design.add_output("yb", &yyb);
1225        design.apply();
1226
1227        let ya_cell = design.find_cell(ya[0]).0;
1228        let mut match_cells = Vec::new();
1229        let ml = MatchTrees::build(&design).cell_into_matrix(ya_cell, &mut match_cells);
1230        assert_eq!(match_cells.len(), 2);
1231        assert_eq!(match_cells[0].output(), ya);
1232        assert_eq!(match_cells[1].output(), yb);
1233        design.apply();
1234
1235        let ya_cell = design.find_cell(yya[0]).0;
1236        let yb_cell = design.find_cell(yyb[0]).0;
1237
1238        let Cell::Buf(ya) = &*ya_cell.get() else { unreachable!() };
1239        let Cell::Buf(yb) = &*yb_cell.get() else { unreachable!() };
1240        let mut mr = MatchMatrix::new(a.concat(b));
1241        mr.add(MatchRow::new(Const::lit("XX000"), [ya[0]]));
1242        mr.add(MatchRow::new(Const::lit("XX111"), [ya[0]]));
1243        mr.add(MatchRow::new(Const::lit("00010"), [ya[1], yb[0]]));
1244        mr.add(MatchRow::new(Const::lit("11010"), [ya[1], yb[1]]));
1245        mr.add(MatchRow::new(Const::lit("XX010"), [ya[1]]));
1246        mr.add(MatchRow::new(Const::lit("XXXXX"), []));
1247
1248        assert_eq!(ml, mr, "\n{ml} != \n{mr}");
1249    }
1250
1251    fn assign(value: impl Into<Value>, enable: impl Into<Net>, update: impl Into<Value>) -> AssignCell {
1252        AssignCell { value: value.into(), enable: enable.into(), update: update.into(), offset: 0 }
1253    }
1254
1255    #[test]
1256    fn test_assign_chains_build_1() {
1257        let mut design = Design::new();
1258        let x = design.add_input("x", 4);
1259        let _a1 = design.add_assign(assign(Value::zero(4), Net::ONE, x));
1260        design.apply();
1261
1262        let AssignChains { chains } = AssignChains::build(&design);
1263
1264        assert!(chains.is_empty());
1265    }
1266
1267    #[test]
1268    fn test_assign_chains_build_2() {
1269        let mut design = Design::new();
1270        let x = design.add_input("x", 4);
1271        let a1 = design.add_assign(assign(Value::zero(4), Net::ONE, x));
1272        let y = design.add_input("y", 4);
1273        let a2 = design.add_assign(assign(a1.clone(), Net::ONE, y));
1274        design.apply();
1275
1276        let a1_cell = design.find_cell(a1[0]).0;
1277        let a2_cell = design.find_cell(a2[0]).0;
1278        let AssignChains { chains } = AssignChains::build(&design);
1279
1280        assert!(chains == [vec![a1_cell, a2_cell]]);
1281    }
1282
1283    #[test]
1284    fn test_assign_chains_build_3_fork() {
1285        let mut design = Design::new();
1286        let x = design.add_input("x", 4);
1287        let a1 = design.add_assign(assign(Value::zero(4), Net::ONE, x));
1288        let y = design.add_input("y", 4);
1289        let _a2 = design.add_assign(assign(a1.clone(), Net::ONE, y));
1290        let z = design.add_input("z", 4);
1291        let _a3 = design.add_assign(assign(a1.clone(), Net::ONE, z));
1292        design.apply();
1293
1294        let AssignChains { chains } = AssignChains::build(&design);
1295
1296        assert!(chains.is_empty());
1297    }
1298
1299    #[test]
1300    fn test_assign_chains_build_partial_update() {
1301        let mut design = Design::new();
1302        let x = design.add_input("x", 4);
1303        let a1 = design.add_assign(assign(Value::zero(4), Net::ONE, x));
1304        let y = design.add_input("y", 3);
1305        let _a2 = design.add_assign(AssignCell { value: a1.clone(), enable: Net::ONE, update: y, offset: 1 });
1306        design.apply();
1307
1308        let AssignChains { chains } = AssignChains::build(&design);
1309
1310        assert!(chains.is_empty());
1311    }
1312
1313    #[test]
1314    fn test_assign_chains_build_partial_value() {
1315        let mut design = Design::new();
1316        let x = design.add_input("x", 4);
1317        let a1 = design.add_assign(assign(Value::zero(4), Net::ONE, x));
1318        let y = design.add_input("y", 3);
1319        let _a2 = design.add_assign(assign(a1.slice(..3), Net::ONE, y));
1320        design.apply();
1321
1322        let AssignChains { chains } = AssignChains::build(&design);
1323
1324        assert!(chains.is_empty());
1325    }
1326
1327    #[test]
1328    fn test_assign_lower_disjoint() {
1329        let mut dl = Design::new();
1330        let c = dl.add_input("c", 2);
1331        let m = dl.add_match(MatchCell {
1332            value: c.clone(),
1333            enable: Net::ONE,
1334            patterns: vec![
1335                vec![
1336                    Const::lit("00"), // x1
1337                    Const::lit("11"), // x1
1338                ],
1339                vec![Const::lit("01")], // x2
1340                vec![Const::lit("10")], // x3
1341            ],
1342        });
1343        let a1 = dl.add_assign(assign(Value::zero(4), m[0], dl.add_input("x1", 4)));
1344        let a2 = dl.add_assign(assign(a1, m[1], dl.add_input("x2", 4)));
1345        let a3 = dl.add_assign(assign(a2, m[2], dl.add_input("x3", 4)));
1346        dl.add_output("y", a3);
1347        dl.apply();
1348
1349        decision(&mut dl);
1350
1351        let mut dr = Design::new();
1352        let c = dr.add_input("c", 2);
1353        let x1 = dr.add_input("x1", 4);
1354        let x2 = dr.add_input("x2", 4);
1355        let x3 = dr.add_input("x3", 4);
1356        let m1 = dr.add_mux(c[1], &x3, &x1);
1357        let m2 = dr.add_mux(c[1], &x1, &x2);
1358        let m3 = dr.add_mux(c[0], m2, m1);
1359        dr.add_output("y", m3);
1360
1361        assert_isomorphic!(dl, dr);
1362    }
1363
1364    #[test]
1365    fn test_assign_lower_disjoint_partial() {
1366        let mut dl = Design::new();
1367        let c = dl.add_input("c", 2);
1368        let m = dl.add_match(MatchCell {
1369            value: c.clone(),
1370            enable: Net::ONE,
1371            patterns: vec![
1372                vec![
1373                    Const::lit("00"), // x1
1374                    Const::lit("11"), // x1
1375                ],
1376                vec![Const::lit("01")], // x2
1377                vec![Const::lit("10")], // x3
1378            ],
1379        });
1380        let a1 = dl.add_assign(assign(Value::zero(4), m[0], dl.add_input("x1", 4)));
1381        let a2 = dl.add_assign(assign(a1, m[1], dl.add_input("x2", 4)));
1382        let a3 = dl.add_assign(assign(a2, m[2], dl.add_input("x3", 3)));
1383        dl.add_output("y", a3);
1384        dl.apply();
1385
1386        decision(&mut dl);
1387        // the particular output generated here is uninteresting, assert that
1388        // lowering doesn't panic and is accepted by SMT
1389    }
1390
1391    #[test]
1392    fn test_assign_lower_disjoint_child() {
1393        let mut dl = Design::new();
1394        let c1 = dl.add_input("c1", 1);
1395        let m1 = dl.add_match(MatchCell {
1396            value: c1,
1397            enable: Net::ONE,
1398            patterns: vec![
1399                vec![Const::lit("0")], // m2
1400                vec![Const::lit("1")], // x2
1401            ],
1402        });
1403
1404        let c2 = dl.add_input("c2", 2);
1405        let m2 = dl.add_match(MatchCell {
1406            value: c2,
1407            enable: m1[0],
1408            patterns: vec![
1409                vec![Const::lit("01")], // x1
1410            ],
1411        });
1412
1413        let a1 = dl.add_assign(assign(Value::zero(4), m2[0], dl.add_input("x1", 4)));
1414        let a2 = dl.add_assign(assign(a1, m1[1], dl.add_input("x2", 4)));
1415        dl.add_output("y", a2);
1416        dl.apply();
1417
1418        decision(&mut dl);
1419
1420        let mut dr = Design::new();
1421        let c1 = dr.add_input("c1", 1);
1422        let c2 = dr.add_input("c2", 2);
1423        let x1 = dr.add_input("x1", 4);
1424        let x2 = dr.add_input("x2", 4);
1425        let m1 = dr.add_mux(c2[1], Value::zero(4), &x1);
1426        let m2 = dr.add_mux(c2[0], &m1, Value::zero(4));
1427        let m3 = dr.add_mux(c1[0], &x2, &m2);
1428        dr.add_output("y", m3);
1429
1430        assert_isomorphic!(dl, dr);
1431    }
1432
1433    #[test]
1434    fn test_assign_lower_overlapping() {
1435        let mut dl = Design::new();
1436        let c = dl.add_input("c", 1);
1437        let m = dl.add_match(MatchCell {
1438            value: c.clone(),
1439            enable: Net::ONE,
1440            patterns: vec![vec![Const::lit("0")], vec![Const::lit("1")]],
1441        });
1442        let a1 = dl.add_assign(assign(Value::zero(4), m[0], dl.add_input("x1", 4)));
1443        let a2 = dl.add_assign(assign(a1, m[1], dl.add_input("x2", 4)));
1444        let a3 = dl.add_assign(assign(a2, m[1], dl.add_input("x3", 4)));
1445        dl.add_output("y", a3);
1446        dl.apply();
1447
1448        decision(&mut dl);
1449
1450        let mut dr = Design::new();
1451        let c = dr.add_input1("c");
1452        let x1 = dr.add_input("x1", 4);
1453        let x2 = dr.add_input("x2", 4);
1454        let x3 = dr.add_input("x3", 4);
1455        let mc = dr.add_mux(c, Const::lit("10"), Const::lit("01"));
1456        let m2 = dr.add_mux(c, x2, x1);
1457        let m3 = dr.add_mux(mc[1], x3, m2);
1458        dr.add_output("y", m3);
1459
1460        assert_isomorphic!(dl, dr);
1461    }
1462
1463    #[test]
1464    fn test_assign_lower_different_matches() {
1465        let mut dl = Design::new();
1466        let c1 = dl.add_input("c1", 1);
1467        let c2 = dl.add_input("c2", 1);
1468        let m1 = dl.add_match(MatchCell { value: c1, enable: Net::ONE, patterns: vec![vec![Const::lit("0")]] });
1469        let m2 = dl.add_match(MatchCell { value: c2, enable: Net::ONE, patterns: vec![vec![Const::lit("0")]] });
1470        let a1 = dl.add_assign(assign(Value::zero(4), m1[0], dl.add_input("x1", 4)));
1471        let a2 = dl.add_assign(assign(a1, m2[0], dl.add_input("x2", 4)));
1472        dl.add_output("y", a2);
1473        dl.apply();
1474
1475        decision(&mut dl);
1476
1477        let mut dr = Design::new();
1478        let c1 = dr.add_input1("c1");
1479        let c2 = dr.add_input1("c2");
1480        let mc2 = dr.add_mux(c2, Const::lit("0"), Const::lit("1"));
1481        let m1 = dr.add_mux(c1, Value::zero(4), dr.add_input("x1", 4));
1482        let m2 = dr.add_mux(mc2[0], dr.add_input("x2", 4), m1);
1483        dr.add_output("y", m2);
1484
1485        assert_isomorphic!(dl, dr);
1486    }
1487
1488    #[test]
1489    fn test_assign_lower_partial() {
1490        let mut dl = Design::new();
1491        let en = dl.add_input1("en");
1492        let assign = dl.add_assign(AssignCell {
1493            value: dl.add_input("value", 6),
1494            enable: en,
1495            update: dl.add_input("update", 3),
1496            offset: 2,
1497        });
1498        dl.add_output("assign", assign);
1499        dl.apply();
1500
1501        decision(&mut dl);
1502
1503        let mut dr = Design::new();
1504        let en = dr.add_input1("en");
1505        let value = dr.add_input("value", 6);
1506        let update = dr.add_input("update", 3);
1507        let mux = dr.add_mux(en, update, value.slice(2..5));
1508        dr.add_output("assign", value.slice(..2).concat(mux.concat(value.slice(5..))));
1509
1510        assert_isomorphic!(dl, dr);
1511    }
1512
1513    #[test]
1514    fn test_match_eq_refinement() {
1515        let mut design = Design::new();
1516        let a = design.add_input("a", 2);
1517        let m = design.add_match(MatchCell {
1518            value: a,
1519            enable: Net::ONE,
1520            patterns: vec![vec![Const::lit("00")], vec![Const::lit("XX")]],
1521        });
1522        design.add_output("y", m);
1523        design.apply();
1524
1525        decision(&mut design);
1526        // the particular output generated here is uninteresting, assert that
1527        // lowering doesn't panic and is accepted by SMT
1528    }
1529}