prjunnamed_generic/
decision.rs

1//! Decision tree lowering.
2//!
3//! The `decision` pass implements decision tree lowering based on the well-known heuristic
4//! algorithm developed for ML from the paper "Tree pattern matching for ML" by Marianne Baudinet
5//! and David MacQueen (unpublished, 1985) the extended abstract of which is available from:
6//! *  <https://smlfamily.github.io/history/Baudinet-DM-tree-pat-match-12-85.pdf> (scan)
7//! *  <https://www.classes.cs.uchicago.edu/archive/2011/spring/22620-1/papers/macqueen-baudinet85.pdf> (OCR)
8//!
9//! The algorithm is described in §4 "Decision trees and the dispatching problem". Only two
10//! of the heuristics described apply here: the relevance heuristic and the branching factor
11//! heuristic.
12
13use std::fmt::Display;
14use std::collections::{BTreeMap, BTreeSet};
15use std::rc::Rc;
16
17use prjunnamed_netlist::{AssignCell, Cell, CellRef, Const, Design, MatchCell, Net, Trit, Value};
18
/// A single row of a match matrix.
///
/// Maps `pattern` (a constant where 0 and 1 match the respective states, and X matches any state)
/// to a set of `rules` (the nets that are asserted if the `pattern` matches the value being tested).
#[derive(Debug, Clone, PartialEq, Eq)]
struct MatchRow {
    /// Per-column mask; `MatchMatrix::add` requires it to be as wide as the matched value.
    pattern: Const,
    /// Nets asserted when this row is the one that matches.
    rules: BTreeSet<Net>,
}
26
/// Matches `value` against ordered `rows` of patterns and their corresponding rules, where the `rules`
/// for the first pattern that matches the value being tested are asserted, and all other `rules`
/// are deasserted.
///
/// Invariant: once a matrix is constructed, there is always a case that matches.
/// Note that due to the possibility of inputs being `X`, this is only
/// satisfiable with a catch-all row.
#[derive(Debug, Clone, PartialEq, Eq)]
struct MatchMatrix {
    /// The nets being tested; column `i` of every row's pattern applies to net `i`.
    value: Value,
    /// Rows in priority order: earlier rows take precedence over later ones.
    rows: Vec<MatchRow>,
}
39
/// Describes the process of testing individual nets of a value (equivalently, eliminating columns
/// from a match matrix) until a specific row is reached.
///
/// Leaves are `Result` nodes; inner nodes are `Branch` nodes with one subtree per state of the
/// tested net.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Decision {
    /// Drive a set of nets to 1.
    Result { rules: BTreeSet<Net> },
    /// Branch on the value of a particular net.
    ///
    /// `if0` is followed when `test` is 0 and `if1` when `test` is 1.
    Branch { test: Net, if0: Box<Decision>, if1: Box<Decision> },
}
49
50impl MatchRow {
51    fn new(pattern: impl Into<Const>, rules: impl IntoIterator<Item = Net>) -> Self {
52        Self { pattern: pattern.into(), rules: BTreeSet::from_iter(rules) }
53    }
54
55    fn empty(pattern_len: usize) -> Self {
56        Self::new(Const::undef(pattern_len), [])
57    }
58
59    fn merge(mut self, other: &MatchRow) -> Self {
60        self.pattern.extend(&other.pattern);
61        self.rules.extend(other.rules.iter().cloned());
62        self
63    }
64}
65
66impl MatchMatrix {
67    fn new(value: impl Into<Value>) -> Self {
68        Self { value: value.into(), rows: Vec::new() }
69    }
70
71    fn add(&mut self, row: MatchRow) -> usize {
72        assert_eq!(self.value.len(), row.pattern.len());
73        self.rows.push(row);
74        self.rows.len() - 1
75    }
76
77    fn add_enable(&mut self, enable: Net) {
78        if enable != Net::ONE {
79            for row in &mut self.rows {
80                row.pattern.push(Trit::One);
81            }
82            self.rows.insert(0, MatchRow::new(Const::undef(self.value.len()).concat(Trit::Zero), []));
83            self.value.push(enable);
84        }
85    }
86
87    /// Merge in a `MatchMatrix` describing a child `match` cell whose `enable`
88    /// input is being driven by `at`.
89    fn merge(mut self, at: Net, other: &MatchMatrix) -> Self {
90        self.value.extend(&other.value);
91        for self_row in std::mem::take(&mut self.rows) {
92            if self_row.rules.contains(&at) {
93                for other_row in &other.rows {
94                    self.add(self_row.clone().merge(&other_row));
95                }
96            } else {
97                self.add(self_row.merge(&MatchRow::empty(other.value.len())));
98            }
99        }
100        self
101    }
102
103    fn iter_outputs(&self) -> impl Iterator<Item = Net> {
104        let mut outputs: Vec<Net> = self.rows.iter().flat_map(|row| row.rules.iter().copied()).collect();
105        outputs.sort();
106        outputs.dedup();
107        outputs.into_iter()
108    }
109
110    fn assume(mut self, target: Net, value: Trit) -> Self {
111        self.value =
112            Value::from_iter(self.value.into_iter().map(|net| if net == target { Net::from(value) } else { net }));
113        self
114    }
115
    /// Remove redundant rows and columns.
    ///
    /// This function ensures the following normal-form properties:
    /// - `self.value` does not contain any constant nets
    /// - every net occurs within `self.value` at most once
    /// - a row of all `X` can only occur at the very end
    /// - no two identical rows can occur immediately next to each other
    ///
    /// Note that the last of these properties is relatively weak, in that
    /// stronger properties exist which can still be feasibly checked.
    fn normalize(mut self) -> Self {
        // Indices of columns and rows that are dropped at the very end; nothing is removed
        // while the analysis below is still indexing into the matrix.
        let mut remove_cols = BTreeSet::new();
        let mut remove_rows = BTreeSet::new();

        // Pick columns to remove where the matched value is constant or has repeated nets.

        // For each `Net`, store the index of the first column being driven
        // by that net.
        let mut first_at = BTreeMap::new();
        for (index, net) in self.value.iter().enumerate() {
            if net.is_cell() && !first_at.contains_key(&net) {
                first_at.insert(net, index);
            } else {
                remove_cols.insert(index);
            }
        }

        // Pick rows to remove that:
        // * are redundant with the immediately preceding row, or
        // * contradict themselves or the constant nets in matched value.
        let mut prev_pattern = None;
        'outer: for (row_index, row) in self.rows.iter_mut().enumerate() {
            // Check if row will never match because of a previous row that:
            // * has the same pattern, or
            // * matches any value.
            if let Some(ref prev_pattern) = prev_pattern {
                if row.pattern == *prev_pattern || prev_pattern.is_undef() {
                    remove_rows.insert(row_index);
                    continue;
                }
            }
            // Remember the pattern as it was before any masks are moved between columns below.
            prev_pattern = Some(row.pattern.clone());
            // Check if row is contradictory.
            for (net_index, net) in self.value.iter().enumerate() {
                let mask = row.pattern[net_index];
                // Row contradicts constant in matched value.
                //
                // Note that if we're matching against a constant `X`, removing
                // the row is nevertheless valid, since the row isn't guaranteed
                // to match regardless of the value of the `X`, and so we can
                // refine it into "doesn't match".
                match net.as_const() {
                    Some(trit) if trit != mask && mask != Trit::Undef => {
                        remove_rows.insert(row_index);
                        continue 'outer;
                    }
                    _ => (),
                }
                // Check if the net appears multiple times in the matched value.
                match first_at.get(&net) {
                    // It doesn't.
                    None => (),
                    // It does and this is the first occurrence. Leave it alone.
                    Some(&first_index) if first_index == net_index => (),
                    // It does and this is the second or later occurrence. Check if it is compatible with
                    // the first one. Also, if the first one was a don't-care, move this one into the position
                    // of the first one.
                    Some(&first_index) => {
                        let first_mask = &mut row.pattern[first_index];
                        if *first_mask != Trit::Undef && mask != Trit::Undef && *first_mask != mask {
                            remove_rows.insert(row_index);
                            continue 'outer;
                        }
                        if *first_mask == Trit::Undef {
                            *first_mask = mask;
                        }
                    }
                }
            }
        }

        // Pick columns to remove where all of the patterns match any value.
        // Only rows that survive removal are taken into account.
        let mut all_undef = vec![true; self.value.len()];
        for (row_index, row) in self.rows.iter().enumerate() {
            if remove_rows.contains(&row_index) {
                continue;
            }
            for col_index in 0..self.value.len() {
                if row.pattern[col_index] != Trit::Undef {
                    all_undef[col_index] = false;
                }
            }
        }
        for (col_index, matches_any) in all_undef.into_iter().enumerate() {
            if matches_any {
                remove_cols.insert(col_index);
            }
        }

        // Execute column and row removal.
        // Yields the elements of `iter` whose position is not listed in `remove_set`.
        fn remove_indices<'a, T>(
            iter: impl IntoIterator<Item = T> + 'a,
            remove_set: &'a BTreeSet<usize>,
        ) -> impl Iterator<Item = T> + 'a {
            iter.into_iter().enumerate().filter_map(|(index, elem)| (!remove_set.contains(&index)).then_some(elem))
        }

        self.value = Value::from_iter(remove_indices(self.value, &remove_cols));
        self.rows = Vec::from_iter(remove_indices(self.rows, &remove_rows));
        for row in &mut self.rows {
            row.pattern = Const::from_iter(remove_indices(row.pattern.iter(), &remove_cols));
        }
        self
    }
230
231    /// Construct a decision tree for the match matrix.
232    fn dispatch(mut self) -> Decision {
233        self = self.normalize();
234        if self.value.is_empty() || self.rows.len() == 1 {
235            Decision::Result { rules: self.rows.into_iter().next().map(|r| r.rules).unwrap_or_default() }
236        } else {
237            // Fanfiction of the heuristics from the 1986 paper that reduces them to: split the matrix on the column
238            // with the fewest don't-care's in it.
239            let mut undef_count = vec![0; self.value.len()];
240            for row in self.rows.iter() {
241                for (col_index, mask) in row.pattern.iter().enumerate() {
242                    if mask == Trit::Undef {
243                        undef_count[col_index] += 1;
244                    }
245                }
246            }
247            let test_index = (0..self.value.len()).min_by_key(|&i| undef_count[i]);
248            let test = self.value[test_index.unwrap()];
249
250            // Split the matrix into two, where the test net has a value of 0 and 1, and recurse.
251            let if0 = self.clone().assume(test, Trit::Zero).dispatch();
252            let if1 = self.assume(test, Trit::One).dispatch();
253            if if0 == if1 {
254                // Skip the branch if the outputs of the decision function are the same. This can happen
255                // e.g. in the following matrix:
256                //   00 => x
257                //   10 => x
258                //   XX => y
259                // regardless of the column selection order. This is readily apparent when the left-hand
260                // column is split on first, but even if the right-hand column is chosen, there will a
261                // case-split with both arms leading to `Decision::Result(x)`.
262                if0
263            } else {
264                Decision::Branch { test, if0: if0.into(), if1: if1.into() }
265            }
266        }
267    }
268}
269
270impl Decision {
271    /// Call `f` on the contents of each `Decision::Result` in this tree.
272    fn each_leaf(&self, f: &mut impl FnMut(&BTreeSet<Net>)) {
273        match self {
274            Decision::Result { rules } => f(rules),
275            Decision::Branch { if0, if1, .. } => {
276                if0.each_leaf(f);
277                if1.each_leaf(f);
278            }
279        }
280    }
281
282    /// Emit a mux-tree that outputs `values[x]` when net `x` would be `1`
283    /// according to the decision tree.
284    ///
285    /// Assumes that each case within `values` is mutually exclusive. Panics if
286    /// that is not the case.
287    fn emit_disjoint_mux(&self, design: &Design, values: &BTreeMap<Net, Value>, default: &Value) -> Value {
288        match self {
289            Decision::Result { rules } => {
290                let mut result = None;
291                for rule in rules {
292                    if let Some(value) = values.get(&rule) {
293                        assert!(result.is_none());
294                        result = Some(value.clone());
295                    }
296                }
297                result.unwrap_or(default.clone())
298            }
299            Decision::Branch { test, if0, if1 } => design.add_mux(
300                *test,
301                if1.emit_disjoint_mux(design, values, default),
302                if0.emit_disjoint_mux(design, values, default),
303            ),
304        }
305    }
306
307    /// Emit a mux-tree that drives the `nets` according to the decision tree.
308    fn emit_one_hot_mux(&self, design: &Design, nets: &Value) -> Value {
309        match self {
310            Decision::Result { rules } => Value::from_iter(
311                nets.iter().map(|net| if rules.contains(&net) { Trit::One } else { Trit::Zero }.into()),
312            ),
313            Decision::Branch { test, if0, if1 } => {
314                design.add_mux(*test, if1.emit_one_hot_mux(design, nets), if0.emit_one_hot_mux(design, nets))
315            }
316        }
317    }
318}
319
/// The `match` cells of a design, grouped into trees linked through their `enable` inputs.
struct MatchTrees<'a> {
    /// The design the cells belong to.
    design: &'a Design,
    /// Set of all `match` cells that aren't children. A `match` cell is a child
    /// of another `match` cell if its `enable` input is being driven from
    /// the output of the parent, and it is the unique such `match` cell for
    /// that particular output bit. That is, if it is possible to merge the
    /// child into the same decision tree.
    roots: BTreeSet<CellRef<'a>>,
    /// Maps a particular output of a `match` cell to the child `match` cell
    /// whose `enable` input it is driving.
    subtrees: BTreeMap<(CellRef<'a>, usize), CellRef<'a>>,
}
332
333impl<'a> MatchTrees<'a> {
334    /// Recognize a tree of `match` cells, connected by their enable inputs.
335    fn build(design: &'a Design) -> MatchTrees<'a> {
336        let mut roots: BTreeSet<CellRef> = BTreeSet::new();
337        let mut subtrees: BTreeMap<(CellRef, usize), BTreeSet<CellRef>> = BTreeMap::new();
338        for cell_ref in design.iter_cells() {
339            let Cell::Match(MatchCell { enable, .. }) = &*cell_ref.get() else { continue };
340            if let Ok((enable_cell_ref, offset)) = design.find_cell(*enable) {
341                if let Cell::Match(_) = &*enable_cell_ref.get() {
342                    // Driven by a match cell; may be a subtree or a root depending on its fanout.
343                    subtrees.entry((enable_cell_ref, offset)).or_default().insert(cell_ref);
344                    continue;
345                }
346            }
347            // Driven by some other cell or a constant; is a root.
348            roots.insert(cell_ref);
349        }
350
351        // Whenever multiple subtrees are connected to the same one-hot output, it is not possible
352        // to merge all of them into the same matrix. Turn all of these subtrees into roots.
353        let subtrees = subtrees
354            .into_iter()
355            .filter_map(|(key, subtrees)| {
356                if subtrees.len() == 1 {
357                    Some((key, subtrees.into_iter().next().unwrap()))
358                } else {
359                    roots.extend(subtrees);
360                    None
361                }
362            })
363            .collect();
364
365        Self { design, roots, subtrees }
366    }
367
368    /// Convert a tree of `match` cells into a matrix.
369    ///
370    /// Collects a list of all the cells being lifted into the matrix into
371    /// `all_cell_refs`.
372    ///
373    /// Replaces outputs that don't have any patterns at all with `Net::ZERO`,
374    /// but otherwise doesn't modify the design.
375    fn cell_into_matrix(&self, cell_ref: CellRef<'a>, all_cell_refs: &mut Vec<CellRef<'a>>) -> MatchMatrix {
376        let Cell::Match(match_cell) = &*cell_ref.get() else { unreachable!() };
377        let output = cell_ref.output();
378        all_cell_refs.push(cell_ref);
379
380        // Create matrix for this cell.
381        let mut matrix = MatchMatrix::new(&match_cell.value);
382        for (output_net, alternates) in output.iter().zip(match_cell.patterns.iter()) {
383            for pattern in alternates {
384                matrix.add(MatchRow::new(pattern.clone(), [output_net]));
385            }
386            if alternates.is_empty() {
387                self.design.replace_net(output_net, Net::ZERO);
388            }
389        }
390        matrix.add(MatchRow::empty(match_cell.value.len()));
391
392        // Create matrices for subtrees and merge them into the matrix for this cell.
393        for (offset, output_net) in output.iter().enumerate() {
394            if let Some(&sub_cell_ref) = self.subtrees.get(&(cell_ref, offset)) {
395                matrix = matrix.merge(output_net, &self.cell_into_matrix(sub_cell_ref, all_cell_refs));
396            }
397        }
398
399        matrix
400    }
401
402    /// For each tree of `match` cells, return a corresponding `MatchMatrix`
403    /// and a list of `match` cells that this matrix implements.
404    fn iter_matrices<'b>(&'b self) -> impl Iterator<Item = (MatchMatrix, Vec<CellRef<'b>>)> + 'b {
405        self.roots.iter().map(|&cell_ref| {
406            let Cell::Match(MatchCell { enable, .. }) = &*cell_ref.get() else { unreachable!() };
407            let mut all_cell_refs = Vec::new();
408            let mut matrix = self.cell_into_matrix(cell_ref, &mut all_cell_refs);
409            matrix.add_enable(*enable);
410            (matrix, all_cell_refs)
411        })
412    }
413}
414
/// Chains of `assign` cells in which every cell updates the whole output of the previous one.
struct AssignChains<'a> {
    /// Each chain lists its cells from the root onwards and holds at least two cells.
    chains: Vec<Vec<CellRef<'a>>>,
}
418
419impl<'a> AssignChains<'a> {
420    fn build(design: &'a Design) -> AssignChains<'a> {
421        let mut roots: BTreeSet<CellRef> = BTreeSet::new();
422        let mut links: BTreeMap<CellRef, BTreeSet<CellRef>> = BTreeMap::new();
423        for cell_ref in design.iter_cells() {
424            let Cell::Assign(AssignCell { value, offset: 0, update, .. }) = &*cell_ref.get() else { continue };
425            if update.len() != value.len() {
426                continue;
427            }
428            if let Ok((value_cell_ref, _offset)) = design.find_cell(value[0]) {
429                if value_cell_ref.output() == *value {
430                    if let Cell::Assign(_) = &*value_cell_ref.get() {
431                        links.entry(value_cell_ref).or_default().insert(cell_ref);
432                        continue;
433                    }
434                }
435            }
436            roots.insert(cell_ref);
437        }
438
439        let mut chains = Vec::new();
440        for root in roots {
441            let mut chain = vec![root];
442            while let Some(links) = links.get(&chain.last().unwrap()) {
443                if links.len() == 1 {
444                    chain.push(*links.first().unwrap());
445                } else {
446                    break;
447                }
448            }
449            if chain.len() > 1 {
450                chains.push(chain);
451            }
452        }
453
454        Self { chains }
455    }
456
457    fn iter_disjoint<'b>(
458        &'b self,
459        decisions: &'a BTreeMap<Net, Rc<Decision>>,
460        occurrences: &BTreeMap<Net, Vec<u32>>,
461    ) -> impl Iterator<Item = (Rc<Decision>, &'b [CellRef<'a>])> {
462        fn enable_of(cell_ref: CellRef) -> Net {
463            let Cell::Assign(AssignCell { enable, .. }) = &*cell_ref.get() else { unreachable!() };
464            *enable
465        }
466
467        self.chains.iter().filter_map(|chain| {
468            let mut used_branches = BTreeSet::new();
469            // Add all branches driving `net` to `used_branches`. Returns
470            // `false` if this is a conflict (i.e. the nets aren't mutually
471            // exclusive).
472            let mut consume_branches = |net: Net| -> bool {
473                let Some(occurs) = occurrences.get(&net) else {
474                    // net is not driven by any branches in this decision tree.
475                    // this can happen if a pattern turns out to be impossible
476                    // (e.g. due to constant propagation)
477                    return true;
478                };
479
480                for &occurrence in occurs {
481                    if !used_branches.insert(occurrence) {
482                        return false;
483                    }
484                }
485
486                true
487            };
488
489            // Check if the enables belong to disjoint branches within the same decision tree
490            // (like in a SystemVerilog "unique" or "unique0" statement).
491            let enable = enable_of(chain[0]);
492            let decision = decisions.get(&enable)?;
493            assert!(consume_branches(enable));
494            let mut end_index = chain.len();
495            'chain: for (index, &other_cell) in chain.iter().enumerate().skip(1) {
496                let enable = enable_of(other_cell);
497                let other_decision = decisions.get(&enable)?;
498                if !Rc::ptr_eq(decision, other_decision) || !consume_branches(enable) {
499                    end_index = index;
500                    break 'chain;
501                }
502            }
503            let chain = &chain[..end_index];
504
505            Some((decision.clone(), chain))
506        })
507    }
508}
509
/// Lowers the `match` and `assign` cells of `design`.
///
/// Trees of `match` cells are combined into match matrices, turned into decision trees, and
/// their outputs are replaced with mux trees. Chains of `assign` cells whose enables are
/// mutually exclusive outputs of one decision tree are replaced with a mux tree that follows
/// that decision tree; every other `assign` cell is replaced with a mux on its own enable.
/// The design is compacted at the end.
pub fn decision(design: &mut Design) {
    // Detect and extract trees of `match` cells present in the netlist.
    let match_trees = MatchTrees::build(design);

    // Detect and extract chains of `assign` statements without slicing.
    let assign_chains = AssignChains::build(design);

    // Combine each tree of `match` cells into a single match matrix.
    // Then build a decision tree for it and use it to drive the output.
    let mut decisions: BTreeMap<Net, Rc<Decision>> = BTreeMap::new();

    // Every leaf of every decision tree gets a unique number; `occurrences` records, for each
    // output net, the numbers of the leaves that assert it.
    let mut next_branch: u32 = 0;
    let mut occurrences: BTreeMap<Net, Vec<u32>> = BTreeMap::new();

    for (matrix, matches) in match_trees.iter_matrices() {
        // Collected up front, because `dispatch` consumes the matrix.
        let all_outputs = BTreeSet::from_iter(matrix.iter_outputs());
        if cfg!(feature = "trace") {
            eprint!(">matrix:\n{matrix}");
        }

        let decision = Rc::new(matrix.dispatch());
        if cfg!(feature = "trace") {
            eprint!(">decision tree:\n{decision}")
        }

        decision.each_leaf(&mut |outputs| {
            let branch = next_branch;
            next_branch += 1;

            for &output in outputs {
                occurrences.entry(output).or_default().push(branch);
            }
        });

        // All outputs of one matrix share the same `Rc`, which is what lets chains be checked
        // for membership in the same tree by pointer identity.
        for &output in &all_outputs {
            decisions.insert(output, decision.clone());
        }

        let _guard = design.use_metadata_from(&matches[..]);
        let nets = Value::from_iter(all_outputs);
        design.replace_value(&nets, decision.emit_one_hot_mux(design, &nets));
    }

    // Find chains of `assign` cells that are order-independent.
    // Then lower these cells to a `mux` tree without `eq` cells.
    let mut used_assigns = BTreeSet::new();
    for (decision, chain) in assign_chains.iter_disjoint(&decisions, &occurrences) {
        let (first_assign, last_assign) = (chain.first().unwrap(), chain.last().unwrap());
        if cfg!(feature = "trace") {
            eprintln!(">disjoint:");
            for &cell_ref in chain {
                eprintln!("{}", design.display_cell(cell_ref));
            }
        }

        // Map each enable in the chain to the update it selects; the value that the first cell
        // of the chain updates serves as the default.
        let mut values = BTreeMap::new();
        let Cell::Assign(AssignCell { value: default, .. }) = &*first_assign.get() else { unreachable!() };
        for &cell_ref in chain {
            let Cell::Assign(AssignCell { enable, update, .. }) = &*cell_ref.get() else { unreachable!() };
            values.insert(*enable, update.clone());
        }

        let _guard = design.use_metadata_from(&chain[..]);
        design.replace_value(last_assign.output(), decision.emit_disjoint_mux(design, &values, default));
        // Only the last cell of the chain is marked as handled; the other cells of the chain
        // still go through the generic lowering below.
        used_assigns.insert(*last_assign);
    }

    // Lower other `assign` cells.
    for cell_ref in design.iter_cells().filter(|cell_ref| !used_assigns.contains(cell_ref)) {
        let Cell::Assign(assign_cell) = &*cell_ref.get() else { continue };
        if cfg!(feature = "trace") {
            eprintln!(">chained: {}", design.display_cell(cell_ref));
        }

        let _guard = design.use_metadata_from(&[cell_ref]);
        // Start from the unmodified value and overwrite only the updated slice with a mux
        // between the update and the previous contents of that slice.
        let mut nets = Vec::from_iter(assign_cell.value.iter());
        let slice = assign_cell.offset..(assign_cell.offset + assign_cell.update.len());
        nets[slice.clone()].copy_from_slice(
            &design.add_mux(assign_cell.enable, &assign_cell.update, assign_cell.value.slice(slice))[..],
        );
        design.replace_value(cell_ref.output(), Value::from(nets));
    }

    design.compact();
}
595
596impl Display for MatchRow {
597    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
598        for (index, trit) in self.pattern.iter().rev().enumerate() {
599            if index != 0 && index % 4 == 0 {
600                write!(f, "_")?;
601            }
602            write!(f, "{trit}")?;
603        }
604        write!(f, " =>")?;
605        if self.rules.is_empty() {
606            return write!(f, " (empty)");
607        }
608        for rule in &self.rules {
609            write!(f, " {rule}")?;
610        }
611        Ok(())
612    }
613}
614
615impl Display for MatchMatrix {
616    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
617        write!(f, "{}:\n", self.value)?;
618        for row in &self.rows {
619            write!(f, "  {row}\n")?;
620        }
621        Ok(())
622    }
623}
624
625impl Decision {
626    fn format(&self, f: &mut std::fmt::Formatter, level: usize) -> std::fmt::Result {
627        let format_rules = |f: &mut std::fmt::Formatter, rules: &BTreeSet<Net>| {
628            if rules.is_empty() {
629                write!(f, " (empty)")
630            } else {
631                for rule in rules {
632                    write!(f, " {rule}")?;
633                }
634                Ok(())
635            }
636        };
637
638        let format_decision = |f: &mut std::fmt::Formatter, net: Net, value: usize, decision: &Decision| {
639            if let Decision::Result { rules } = decision {
640                if rules.is_empty() {
641                    return Ok(());
642                }
643            }
644            for _ in 0..level {
645                write!(f, "  ")?;
646            }
647            match decision {
648                Decision::Result { rules } => {
649                    write!(f, "{net} = {value} =>")?;
650                    format_rules(f, &rules)?;
651                    write!(f, "\n")
652                }
653                Decision::Branch { .. } => {
654                    write!(f, "{net} = {value} =>\n")?;
655                    decision.format(f, level + 1)
656                }
657            }
658        };
659
660        match self {
661            Decision::Result { rules } => {
662                assert_eq!(level, 0);
663                write!(f, "=>")?;
664                format_rules(f, &rules)?;
665                write!(f, "\n")?;
666            }
667            Decision::Branch { test, if0, if1 } => {
668                format_decision(f, *test, 0, if0)?;
669                format_decision(f, *test, 1, if1)?;
670            }
671        }
672        Ok(())
673    }
674}
675
676impl Display for Decision {
677    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
678        self.format(f, 0)
679    }
680}
681
682#[cfg(test)]
683mod test {
684    #![allow(non_snake_case)]
685
686    use std::collections::{BTreeMap, BTreeSet};
687
688    use prjunnamed_netlist::{assert_isomorphic, AssignCell, Cell, Const, Design, MatchCell, Net, Value};
689
690    use super::{decision, AssignChains, Decision, MatchMatrix, MatchRow, MatchTrees};
691
692    struct Helper(Design);
693
694    impl Helper {
695        fn new() -> Self {
696            Self(Design::new())
697        }
698
699        fn val(&self, width: usize) -> Value {
700            self.0.add_void(width)
701        }
702
703        fn net(&self) -> Net {
704            self.0.add_void(1).unwrap_net()
705        }
706
707        fn rs(&self, rule: Net) -> Box<Decision> {
708            Box::new(Decision::Result { rules: BTreeSet::from_iter([rule]) })
709        }
710
711        fn br(&self, test: Net, if1: Box<Decision>, if0: Box<Decision>) -> Box<Decision> {
712            Box::new(Decision::Branch { test, if0, if1 })
713        }
714    }
715
716    #[test]
717    fn test_add_enable() {
718        let h = Helper::new();
719
720        let v = h.val(2);
721        let (n1, n2, en) = (h.net(), h.net(), h.net());
722
723        let mut ml = MatchMatrix::new(&v);
724        ml.add(MatchRow::new(Const::lit("10"), [n1]));
725        ml.add(MatchRow::new(Const::lit("01"), [n2]));
726
727        let mut mr = MatchMatrix::new(&v.concat(en));
728        mr.add(MatchRow::new(Const::lit("0XX"), []));
729        mr.add(MatchRow::new(Const::lit("110"), [n1]));
730        mr.add(MatchRow::new(Const::lit("101"), [n2]));
731
732        ml.add_enable(en);
733        assert_eq!(ml, mr, "\n{ml} != \n{mr}");
734    }
735
736    #[test]
737    fn test_add_enable_trivial() {
738        let h = Helper::new();
739
740        let v = h.val(2);
741        let (n1, n2) = (h.net(), h.net());
742
743        let mut ml = MatchMatrix::new(&v);
744        ml.add(MatchRow::new(Const::lit("10"), [n1]));
745        ml.add(MatchRow::new(Const::lit("01"), [n2]));
746
747        let mr = ml.clone();
748
749        ml.add_enable(Net::ONE);
750        assert_eq!(ml, mr, "\n{ml} != \n{mr}");
751    }
752
753    #[test]
754    fn test_merge_1() {
755        let h = Helper::new();
756
757        let v1 = h.val(2);
758        let (n11, n12) = (h.net(), h.net());
759        let v2 = h.val(3);
760        let (n21, n22) = (h.net(), h.net());
761        let mut m1 = MatchMatrix::new(&v1);
762        m1.add(MatchRow::new(Const::lit("10"), [n11]));
763        m1.add(MatchRow::new(Const::lit("01"), [n12]));
764        m1.add(MatchRow::new(Const::lit("XX"), []));
765
766        let mut m2 = MatchMatrix::new(&v2);
767        m2.add(MatchRow::new(Const::lit("X00"), [n21]));
768        m2.add(MatchRow::new(Const::lit("10X"), [n22]));
769        m2.add(MatchRow::new(Const::lit("XXX"), []));
770
771        let ml1 = m1.clone().merge(n12, &m2);
772
773        let mut mr1 = MatchMatrix::new(v1.concat(&v2));
774        mr1.add(MatchRow::new(Const::lit("XXX10"), [n11]));
775        mr1.add(MatchRow::new(Const::lit("X0001"), [n12, n21]));
776        mr1.add(MatchRow::new(Const::lit("10X01"), [n12, n22]));
777        mr1.add(MatchRow::new(Const::lit("XXX01"), [n12]));
778        mr1.add(MatchRow::new(Const::lit("XXXXX"), []));
779
780        assert_eq!(ml1, mr1, "\n{ml1} != \n{mr1}");
781    }
782
783    #[test]
784    fn test_merge_2() {
785        let h = Helper::new();
786
787        let v1 = h.val(2);
788        let (n11, n12) = (h.net(), h.net());
789        let v2 = h.val(3);
790        let (n21, n22) = (h.net(), h.net());
791        let mut m1 = MatchMatrix::new(&v1);
792        m1.add(MatchRow::new(Const::lit("10"), [n11]));
793        m1.add(MatchRow::new(Const::lit("01"), [n11]));
794        m1.add(MatchRow::new(Const::lit("XX"), [n12]));
795
796        let mut m2 = MatchMatrix::new(&v2);
797        m2.add(MatchRow::new(Const::lit("X00"), [n21]));
798        m2.add(MatchRow::new(Const::lit("10X"), [n22]));
799        m2.add(MatchRow::new(Const::lit("XXX"), []));
800
801        let ml1 = m1.clone().merge(n11, &m2);
802
803        let mut mr1 = MatchMatrix::new(v1.concat(&v2));
804        mr1.add(MatchRow::new(Const::lit("X0010"), [n11, n21]));
805        mr1.add(MatchRow::new(Const::lit("10X10"), [n11, n22]));
806        mr1.add(MatchRow::new(Const::lit("XXX10"), [n11]));
807        mr1.add(MatchRow::new(Const::lit("X0001"), [n11, n21]));
808        mr1.add(MatchRow::new(Const::lit("10X01"), [n11, n22]));
809        mr1.add(MatchRow::new(Const::lit("XXX01"), [n11]));
810        mr1.add(MatchRow::new(Const::lit("XXXXX"), [n12]));
811
812        assert_eq!(ml1, mr1, "\n{ml1} != \n{mr1}");
813    }
814
815    #[test]
816    fn test_normalize_vertical() {
817        let h = Helper::new();
818        let n = h.net();
819
820        let mut m00 = MatchMatrix::new(&Value::from(Const::lit("0")));
821        m00.add(MatchRow::new(Const::lit("0"), [n]));
822
823        let mut m01 = MatchMatrix::new(&Value::from(Const::lit("0")));
824        m01.add(MatchRow::new(Const::lit("1"), [n]));
825
826        let mut m0X = MatchMatrix::new(&Value::from(Const::lit("0")));
827        m0X.add(MatchRow::new(Const::lit("X"), [n]));
828
829        let mut m10 = MatchMatrix::new(&Value::from(Const::lit("1")));
830        m10.add(MatchRow::new(Const::lit("0"), [n]));
831
832        let mut m11 = MatchMatrix::new(&Value::from(Const::lit("1")));
833        m11.add(MatchRow::new(Const::lit("1"), [n]));
834
835        let mut m1X = MatchMatrix::new(&Value::from(Const::lit("1")));
836        m1X.add(MatchRow::new(Const::lit("X"), [n]));
837
838        let mut mX0 = MatchMatrix::new(&Value::from(Const::lit("X")));
839        mX0.add(MatchRow::new(Const::lit("0"), [n]));
840
841        let mut mX1 = MatchMatrix::new(&Value::from(Const::lit("X")));
842        mX1.add(MatchRow::new(Const::lit("1"), [n]));
843
844        let mut mXX = MatchMatrix::new(&Value::from(Const::lit("X")));
845        mXX.add(MatchRow::new(Const::lit("X"), [n]));
846
847        for must_reject in [m01, m10, mX0, mX1] {
848            let normalized = must_reject.clone().normalize();
849            assert_eq!(normalized.rows.len(), 0, "before:\n{must_reject}\nafter:\n{normalized}");
850        }
851        for must_accept in [m00, m0X, m11, m1X, mXX] {
852            let normalized = must_accept.clone().normalize();
853            assert_eq!(normalized.rows.len(), 1, "before:\n{must_accept}\nafter:\n{normalized}");
854            assert_eq!(normalized.rows[0].pattern.len(), 0, "before:\n{must_accept}\nafter:\n{normalized}");
855        }
856    }
857
858    #[test]
859    fn test_normalize_horizontal() {
860        let h = Helper::new();
861        let v = h.val(1);
862        let n = h.net();
863
864        let mut m1 = MatchMatrix::new(v.concat(&v));
865        m1.add(MatchRow::new(Const::lit("0X"), [n]));
866        m1 = m1.normalize();
867        assert_eq!(m1.rows[0].pattern, Const::lit("0"));
868
869        let mut m2 = MatchMatrix::new(v.concat(&v));
870        m2.add(MatchRow::new(Const::lit("X0"), [n]));
871        m2 = m2.normalize();
872        assert_eq!(m2.rows[0].pattern, Const::lit("0"));
873
874        let mut m3 = MatchMatrix::new(v.concat(&v));
875        m3.add(MatchRow::new(Const::lit("10"), [n]));
876        m3 = m3.normalize();
877        assert_eq!(m3.rows.len(), 0);
878    }
879
880    #[test]
881    fn test_normalize_duplicate_row() {
882        let h = Helper::new();
883        let v = h.val(2);
884        let (n1, n2) = (h.net(), h.net());
885
886        let mut m = MatchMatrix::new(v);
887        m.add(MatchRow::new(Const::lit("10"), [n1]));
888        m.add(MatchRow::new(Const::lit("10"), [n2]));
889        m = m.normalize();
890        assert_eq!(m.rows.len(), 1);
891        assert_eq!(m.rows[0].pattern, Const::lit("10"));
892        assert_eq!(m.rows[0].rules, BTreeSet::from_iter([n1]));
893    }
894
895    #[test]
896    fn test_normalize_irrefitable() {
897        let h = Helper::new();
898        let v = h.val(2);
899        let (n1, n2) = (h.net(), h.net());
900
901        let mut m = MatchMatrix::new(v);
902        m.add(MatchRow::new(Const::lit("XX"), [n1]));
903        m.add(MatchRow::new(Const::lit("10"), [n2]));
904        m = m.normalize();
905        assert_eq!(m.rows.len(), 1);
906        assert_eq!(m.rows[0].pattern, Const::lit(""));
907        assert_eq!(m.rows[0].rules, BTreeSet::from_iter([n1]));
908    }
909
910    #[test]
911    fn test_normalize_unused_column() {
912        let h = Helper::new();
913        let v = h.val(2);
914        let (n1, n2) = (h.net(), h.net());
915
916        let mut m = MatchMatrix::new(&v);
917        m.add(MatchRow::new(Const::lit("X0"), [n1]));
918        m.add(MatchRow::new(Const::lit("X1"), [n2]));
919        m = m.normalize();
920        assert_eq!(m.value, v.slice(0..1));
921        assert_eq!(m.rows.len(), 2);
922        assert_eq!(m.rows[0], MatchRow::new(Const::lit("0"), [n1]));
923        assert_eq!(m.rows[1], MatchRow::new(Const::lit("1"), [n2]));
924    }
925
926    #[test]
927    fn test_normalize_unused_column_after_elim() {
928        let h = Helper::new();
929        let v = h.val(2);
930        let (n1, n2, n3) = (h.net(), h.net(), h.net());
931
932        let mut m = MatchMatrix::new(&v.concat(&v));
933        m.add(MatchRow::new(Const::lit("XXX0"), [n1]));
934        m.add(MatchRow::new(Const::lit("XXX1"), [n2]));
935        m.add(MatchRow::new(Const::lit("1X0X"), [n3]));
936        m = m.normalize();
937        assert_eq!(m.value, v.slice(0..1));
938        assert_eq!(m.rows.len(), 2);
939        assert_eq!(m.rows[0], MatchRow::new(Const::lit("0"), [n1]));
940        assert_eq!(m.rows[1], MatchRow::new(Const::lit("1"), [n2]));
941    }
942
943    macro_rules! assert_dispatch {
944        ($m:expr, $d:expr) => {
945            let dl = $m.clone().dispatch();
946            let dr = $d;
947            assert!(dl == *dr, "\ndispatching {}\n{} != \n{}", $m, dl, dr);
948        };
949    }
950
951    #[test]
952    fn test_decide_0() {
953        let h = Helper::new();
954
955        let v = h.val(1);
956        let n = h.net();
957        let mut m = MatchMatrix::new(&v);
958        m.add(MatchRow::new(Const::lit("0"), [n]));
959
960        let d = h.rs(n);
961
962        assert_dispatch!(m, d);
963    }
964
965    #[test]
966    fn test_decide_0_1() {
967        let h = Helper::new();
968
969        let v = h.val(1);
970        let (n1, n2) = (h.net(), h.net());
971        let mut m = MatchMatrix::new(&v);
972        m.add(MatchRow::new(Const::lit("0"), [n1]));
973        m.add(MatchRow::new(Const::lit("1"), [n2]));
974
975        let d = h.br(v[0], h.rs(n2), h.rs(n1));
976
977        assert_dispatch!(m, d);
978    }
979
980    #[test]
981    fn test_decide_0_X() {
982        let h = Helper::new();
983
984        let v = h.val(1);
985        let (n1, n2) = (h.net(), h.net());
986        let mut m = MatchMatrix::new(&v);
987        m.add(MatchRow::new(Const::lit("0"), [n1]));
988        m.add(MatchRow::new(Const::lit("X"), [n2]));
989
990        let d = h.br(v[0], h.rs(n2), h.rs(n1));
991
992        assert_dispatch!(m, d);
993    }
994
995    #[test]
996    fn test_decide_1_X() {
997        let h = Helper::new();
998
999        let v = h.val(1);
1000        let (n1, n2) = (h.net(), h.net());
1001        let mut m = MatchMatrix::new(&v);
1002        m.add(MatchRow::new(Const::lit("1"), [n1]));
1003        m.add(MatchRow::new(Const::lit("X"), [n2]));
1004
1005        let d = h.br(v[0], h.rs(n1), h.rs(n2));
1006
1007        assert_dispatch!(m, d);
1008    }
1009
1010    #[test]
1011    fn test_decide_X_0_1() {
1012        let h = Helper::new();
1013
1014        let v = h.val(1);
1015        let (n1, n2, n3) = (h.net(), h.net(), h.net());
1016        let mut m = MatchMatrix::new(&v);
1017        m.add(MatchRow::new(Const::lit("X"), [n1]));
1018        m.add(MatchRow::new(Const::lit("0"), [n2]));
1019        m.add(MatchRow::new(Const::lit("1"), [n3]));
1020
1021        let d = h.rs(n1);
1022
1023        assert_dispatch!(m, d);
1024    }
1025
1026    #[test]
1027    fn test_decide_0_1_X() {
1028        let h = Helper::new();
1029
1030        let v = h.val(1);
1031        let (n1, n2, n3) = (h.net(), h.net(), h.net());
1032        let mut m = MatchMatrix::new(&v);
1033        m.add(MatchRow::new(Const::lit("0"), [n1]));
1034        m.add(MatchRow::new(Const::lit("1"), [n2]));
1035        m.add(MatchRow::new(Const::lit("X"), [n3]));
1036
1037        let d = h.br(v[0], h.rs(n2), h.rs(n1));
1038
1039        assert_dispatch!(m, d);
1040    }
1041
1042    #[test]
1043    fn test_decide_0X_1X_XX() {
1044        let h = Helper::new();
1045
1046        let v = h.val(2);
1047        let (n1, n2, n3) = (h.net(), h.net(), h.net());
1048        let mut m = MatchMatrix::new(&v);
1049        m.add(MatchRow::new(Const::lit("X0"), [n1]));
1050        m.add(MatchRow::new(Const::lit("X1"), [n2]));
1051        m.add(MatchRow::new(Const::lit("XX"), [n3]));
1052
1053        let d = h.br(v[0], h.rs(n2), h.rs(n1));
1054
1055        assert_dispatch!(m, d);
1056    }
1057
1058    #[test]
1059    fn test_decide_0X_11_XX() {
1060        let h = Helper::new();
1061
1062        let v = h.val(2);
1063        let (n1, n2, n3) = (h.net(), h.net(), h.net());
1064        let mut m = MatchMatrix::new(&v);
1065        m.add(MatchRow::new(Const::lit("0X"), [n1]));
1066        m.add(MatchRow::new(Const::lit("11"), [n2]));
1067        m.add(MatchRow::new(Const::lit("XX"), [n3]));
1068
1069        let d = h.br(v[1], h.br(v[0], h.rs(n2), h.rs(n3)), h.rs(n1));
1070
1071        assert_dispatch!(m, d);
1072    }
1073
1074    #[test]
1075    fn test_decide_00_10_XX() {
1076        let h = Helper::new();
1077
1078        let v = h.val(2);
1079        let (n1, n2) = (h.net(), h.net());
1080        let mut m = MatchMatrix::new(&v);
1081        m.add(MatchRow::new(Const::lit("00"), [n1]));
1082        m.add(MatchRow::new(Const::lit("01"), [n1]));
1083        m.add(MatchRow::new(Const::lit("XX"), [n2]));
1084
1085        let d = h.br(v[1], h.rs(n2), h.rs(n1));
1086
1087        assert_dispatch!(m, d);
1088    }
1089
1090    #[test]
1091    fn test_match_tree_build_root_1() {
1092        let mut design = Design::new();
1093        let a = design.add_input("a", 1);
1094        let y = design.add_match(MatchCell { value: a, enable: Net::ONE, patterns: vec![vec![Const::lit("0")]] });
1095        design.apply();
1096
1097        let y_cell = design.find_cell(y[0]).unwrap().0;
1098
1099        let match_trees = MatchTrees::build(&design);
1100        assert!(match_trees.roots == BTreeSet::from_iter([y_cell]));
1101        assert!(match_trees.subtrees == BTreeMap::from_iter([]));
1102    }
1103
1104    #[test]
1105    fn test_match_tree_build_root_pi() {
1106        let mut design = Design::new();
1107        let a = design.add_input("a", 1);
1108        let b = design.add_input1("b");
1109        let y = design.add_match(MatchCell { value: a, enable: b, patterns: vec![vec![Const::lit("0")]] });
1110        design.apply();
1111
1112        let y_cell = design.find_cell(y[0]).unwrap().0;
1113
1114        let match_trees = MatchTrees::build(&design);
1115        assert!(match_trees.roots == BTreeSet::from_iter([y_cell]));
1116        assert!(match_trees.subtrees == BTreeMap::from_iter([]));
1117    }
1118
1119    #[test]
1120    fn test_match_tree_build_root_subtree() {
1121        let mut design = Design::new();
1122        let a = design.add_input("a", 1);
1123        let y1 =
1124            design.add_match(MatchCell { value: a.clone(), enable: Net::ONE, patterns: vec![vec![Const::lit("0")]] });
1125        let y2 = design.add_match(MatchCell { value: a, enable: y1[0], patterns: vec![vec![Const::lit("0")]] });
1126        design.apply();
1127
1128        let y1_cell = design.find_cell(y1[0]).unwrap().0;
1129        let y2_cell = design.find_cell(y2[0]).unwrap().0;
1130
1131        let match_trees = MatchTrees::build(&design);
1132        assert!(match_trees.roots == BTreeSet::from_iter([y1_cell]));
1133        assert!(match_trees.subtrees == BTreeMap::from_iter([((y1_cell, 0), y2_cell)]));
1134    }
1135
1136    #[test]
1137    fn test_match_tree_build_root_subtrees_disjoint() {
1138        let mut design = Design::new();
1139        let a = design.add_input("a", 1);
1140        let y1 = design.add_match(MatchCell {
1141            value: a.clone(),
1142            enable: Net::ONE,
1143            patterns: vec![vec![Const::lit("0")], vec![Const::lit("1")]],
1144        });
1145        let y2 = design.add_match(MatchCell { value: a.clone(), enable: y1[0], patterns: vec![vec![Const::lit("0")]] });
1146        let y3 = design.add_match(MatchCell { value: a, enable: y1[1], patterns: vec![vec![Const::lit("0")]] });
1147        design.apply();
1148
1149        let y1_cell = design.find_cell(y1[0]).unwrap().0;
1150        let y2_cell = design.find_cell(y2[0]).unwrap().0;
1151        let y3_cell = design.find_cell(y3[0]).unwrap().0;
1152
1153        let match_trees = MatchTrees::build(&design);
1154        assert!(match_trees.roots == BTreeSet::from_iter([y1_cell]));
1155        assert!(match_trees.subtrees == BTreeMap::from_iter([((y1_cell, 0), y2_cell), ((y1_cell, 1), y3_cell)]));
1156    }
1157
1158    #[test]
1159    fn test_match_tree_build_root_subtrees_rerooted() {
1160        let mut design = Design::new();
1161        let a = design.add_input("a", 1);
1162        let y1 =
1163            design.add_match(MatchCell { value: a.clone(), enable: Net::ONE, patterns: vec![vec![Const::lit("0")]] });
1164        let y2 = design.add_match(MatchCell { value: a.clone(), enable: y1[0], patterns: vec![vec![Const::lit("0")]] });
1165        let y3 = design.add_match(MatchCell { value: a, enable: y1[0], patterns: vec![vec![Const::lit("1")]] });
1166        design.apply();
1167
1168        let y1_cell = design.find_cell(y1[0]).unwrap().0;
1169        let y2_cell = design.find_cell(y2[0]).unwrap().0;
1170        let y3_cell = design.find_cell(y3[0]).unwrap().0;
1171
1172        let match_trees = MatchTrees::build(&design);
1173        assert!(match_trees.roots == BTreeSet::from_iter([y1_cell, y2_cell, y3_cell]));
1174        assert!(match_trees.subtrees == BTreeMap::from_iter([]));
1175    }
1176
1177    #[test]
1178    fn test_match_cell_into_matrix_flat() {
1179        let mut design = Design::new();
1180        let a = design.add_input("a", 3);
1181        let y = design.add_match(MatchCell {
1182            value: a.clone(),
1183            enable: Net::ONE,
1184            patterns: vec![vec![Const::lit("000"), Const::lit("111")], vec![Const::lit("010")]],
1185        });
1186        let yy = design.add_buf(&y);
1187        design.add_output("y", &yy);
1188        design.apply();
1189
1190        let y_cell = design.find_cell(y[0]).unwrap().0;
1191        let mut match_cells = Vec::new();
1192        let m = MatchTrees::build(&design).cell_into_matrix(y_cell, &mut match_cells);
1193        assert_eq!(match_cells.len(), 1);
1194        assert_eq!(match_cells[0].output(), y);
1195        design.apply();
1196
1197        let yy_cell = design.find_cell(yy[0]).unwrap().0;
1198        let Cell::Buf(y) = &*yy_cell.get() else { unreachable!() };
1199        assert_eq!(m.value, a);
1200        assert_eq!(m.rows, vec![
1201            MatchRow::new(Const::lit("000"), [y[0]]),
1202            MatchRow::new(Const::lit("111"), [y[0]]),
1203            MatchRow::new(Const::lit("010"), [y[1]]),
1204            MatchRow::new(Const::lit("XXX"), []),
1205        ]);
1206    }
1207
1208    #[test]
1209    fn test_match_cell_into_matrix_nested() {
1210        let mut design = Design::new();
1211        let a = design.add_input("a", 3);
1212        let b = design.add_input("b", 2);
1213        let ya = design.add_match(MatchCell {
1214            value: a.clone(),
1215            enable: Net::ONE,
1216            patterns: vec![vec![Const::lit("000"), Const::lit("111")], vec![Const::lit("010")]],
1217        });
1218        let yb = design.add_match(MatchCell {
1219            value: b.clone(),
1220            enable: ya[1],
1221            patterns: vec![vec![Const::lit("00")], vec![Const::lit("11")]],
1222        });
1223        let yya = design.add_buf(&ya);
1224        let yyb = design.add_buf(&yb);
1225        design.add_output("ya", &yya);
1226        design.add_output("yb", &yyb);
1227        design.apply();
1228
1229        let ya_cell = design.find_cell(ya[0]).unwrap().0;
1230        let mut match_cells = Vec::new();
1231        let ml = MatchTrees::build(&design).cell_into_matrix(ya_cell, &mut match_cells);
1232        assert_eq!(match_cells.len(), 2);
1233        assert_eq!(match_cells[0].output(), ya);
1234        assert_eq!(match_cells[1].output(), yb);
1235        design.apply();
1236
1237        let ya_cell = design.find_cell(yya[0]).unwrap().0;
1238        let yb_cell = design.find_cell(yyb[0]).unwrap().0;
1239
1240        let Cell::Buf(ya) = &*ya_cell.get() else { unreachable!() };
1241        let Cell::Buf(yb) = &*yb_cell.get() else { unreachable!() };
1242        let mut mr = MatchMatrix::new(a.concat(b));
1243        mr.add(MatchRow::new(Const::lit("XX000"), [ya[0]]));
1244        mr.add(MatchRow::new(Const::lit("XX111"), [ya[0]]));
1245        mr.add(MatchRow::new(Const::lit("00010"), [ya[1], yb[0]]));
1246        mr.add(MatchRow::new(Const::lit("11010"), [ya[1], yb[1]]));
1247        mr.add(MatchRow::new(Const::lit("XX010"), [ya[1]]));
1248        mr.add(MatchRow::new(Const::lit("XXXXX"), []));
1249
1250        assert_eq!(ml, mr, "\n{ml} != \n{mr}");
1251    }
1252
1253    fn assign(value: impl Into<Value>, enable: impl Into<Net>, update: impl Into<Value>) -> AssignCell {
1254        AssignCell { value: value.into(), enable: enable.into(), update: update.into(), offset: 0 }
1255    }
1256
1257    #[test]
1258    fn test_assign_chains_build_1() {
1259        let mut design = Design::new();
1260        let x = design.add_input("x", 4);
1261        let _a1 = design.add_assign(assign(Value::zero(4), Net::ONE, x));
1262        design.apply();
1263
1264        let AssignChains { chains } = AssignChains::build(&design);
1265
1266        assert!(chains.is_empty());
1267    }
1268
1269    #[test]
1270    fn test_assign_chains_build_2() {
1271        let mut design = Design::new();
1272        let x = design.add_input("x", 4);
1273        let a1 = design.add_assign(assign(Value::zero(4), Net::ONE, x));
1274        let y = design.add_input("y", 4);
1275        let a2 = design.add_assign(assign(a1.clone(), Net::ONE, y));
1276        design.apply();
1277
1278        let a1_cell = design.find_cell(a1[0]).unwrap().0;
1279        let a2_cell = design.find_cell(a2[0]).unwrap().0;
1280        let AssignChains { chains } = AssignChains::build(&design);
1281
1282        assert!(chains == &[vec![a1_cell, a2_cell]]);
1283    }
1284
1285    #[test]
1286    fn test_assign_chains_build_3_fork() {
1287        let mut design = Design::new();
1288        let x = design.add_input("x", 4);
1289        let a1 = design.add_assign(assign(Value::zero(4), Net::ONE, x));
1290        let y = design.add_input("y", 4);
1291        let _a2 = design.add_assign(assign(a1.clone(), Net::ONE, y));
1292        let z = design.add_input("z", 4);
1293        let _a3 = design.add_assign(assign(a1.clone(), Net::ONE, z));
1294        design.apply();
1295
1296        let AssignChains { chains } = AssignChains::build(&design);
1297
1298        assert!(chains.is_empty());
1299    }
1300
1301    #[test]
1302    fn test_assign_chains_build_partial_update() {
1303        let mut design = Design::new();
1304        let x = design.add_input("x", 4);
1305        let a1 = design.add_assign(assign(Value::zero(4), Net::ONE, x));
1306        let y = design.add_input("y", 3);
1307        let _a2 = design.add_assign(AssignCell { value: a1.clone(), enable: Net::ONE, update: y, offset: 1 });
1308        design.apply();
1309
1310        let AssignChains { chains } = AssignChains::build(&design);
1311
1312        assert!(chains.is_empty());
1313    }
1314
1315    #[test]
1316    fn test_assign_chains_build_partial_value() {
1317        let mut design = Design::new();
1318        let x = design.add_input("x", 4);
1319        let a1 = design.add_assign(assign(Value::zero(4), Net::ONE, x));
1320        let y = design.add_input("y", 3);
1321        let _a2 = design.add_assign(assign(a1.slice(..3), Net::ONE, y));
1322        design.apply();
1323
1324        let AssignChains { chains } = AssignChains::build(&design);
1325
1326        assert!(chains.is_empty());
1327    }
1328
1329    #[test]
1330    fn test_assign_lower_disjoint() {
1331        let mut dl = Design::new();
1332        let c = dl.add_input("c", 2);
1333        let m = dl.add_match(MatchCell {
1334            value: c.clone(),
1335            enable: Net::ONE,
1336            patterns: vec![
1337                vec![
1338                    Const::lit("00"), // x1
1339                    Const::lit("11"), // x1
1340                ],
1341                vec![Const::lit("01")], // x2
1342                vec![Const::lit("10")], // x3
1343            ],
1344        });
1345        let a1 = dl.add_assign(assign(Value::zero(4), m[0], dl.add_input("x1", 4)));
1346        let a2 = dl.add_assign(assign(a1, m[1], dl.add_input("x2", 4)));
1347        let a3 = dl.add_assign(assign(a2, m[2], dl.add_input("x3", 4)));
1348        dl.add_output("y", a3);
1349        dl.apply();
1350
1351        decision(&mut dl);
1352
1353        let mut dr = Design::new();
1354        let c = dr.add_input("c", 2);
1355        let x1 = dr.add_input("x1", 4);
1356        let x2 = dr.add_input("x2", 4);
1357        let x3 = dr.add_input("x3", 4);
1358        let m1 = dr.add_mux(c[1], &x3, &x1);
1359        let m2 = dr.add_mux(c[1], &x1, &x2);
1360        let m3 = dr.add_mux(c[0], m2, m1);
1361        dr.add_output("y", m3);
1362
1363        assert_isomorphic!(dl, dr);
1364    }
1365
1366    #[test]
1367    fn test_assign_lower_disjoint_partial() {
1368        let mut dl = Design::new();
1369        let c = dl.add_input("c", 2);
1370        let m = dl.add_match(MatchCell {
1371            value: c.clone(),
1372            enable: Net::ONE,
1373            patterns: vec![
1374                vec![
1375                    Const::lit("00"), // x1
1376                    Const::lit("11"), // x1
1377                ],
1378                vec![Const::lit("01")], // x2
1379                vec![Const::lit("10")], // x3
1380            ],
1381        });
1382        let a1 = dl.add_assign(assign(Value::zero(4), m[0], dl.add_input("x1", 4)));
1383        let a2 = dl.add_assign(assign(a1, m[1], dl.add_input("x2", 4)));
1384        let a3 = dl.add_assign(assign(a2, m[2], dl.add_input("x3", 3)));
1385        dl.add_output("y", a3);
1386        dl.apply();
1387
1388        decision(&mut dl);
1389        // the particular output generated here is uninteresting, assert that
1390        // lowering doesn't panic and is accepted by SMT
1391    }
1392
1393    #[test]
1394    fn test_assign_lower_disjoint_child() {
1395        let mut dl = Design::new();
1396        let c1 = dl.add_input("c1", 1);
1397        let m1 = dl.add_match(MatchCell {
1398            value: c1,
1399            enable: Net::ONE,
1400            patterns: vec![
1401                vec![Const::lit("0")], // m2
1402                vec![Const::lit("1")], // x2
1403            ],
1404        });
1405
1406        let c2 = dl.add_input("c2", 2);
1407        let m2 = dl.add_match(MatchCell {
1408            value: c2,
1409            enable: m1[0],
1410            patterns: vec![
1411                vec![Const::lit("01")], // x1
1412            ],
1413        });
1414
1415        let a1 = dl.add_assign(assign(Value::zero(4), m2[0], dl.add_input("x1", 4)));
1416        let a2 = dl.add_assign(assign(a1, m1[1], dl.add_input("x2", 4)));
1417        dl.add_output("y", a2);
1418        dl.apply();
1419
1420        decision(&mut dl);
1421
1422        let mut dr = Design::new();
1423        let c1 = dr.add_input("c1", 1);
1424        let c2 = dr.add_input("c2", 2);
1425        let x1 = dr.add_input("x1", 4);
1426        let x2 = dr.add_input("x2", 4);
1427        let m1 = dr.add_mux(c2[1], Value::zero(4), &x1);
1428        let m2 = dr.add_mux(c2[0], &m1, Value::zero(4));
1429        let m3 = dr.add_mux(c1[0], &x2, &m2);
1430        dr.add_output("y", m3);
1431
1432        assert_isomorphic!(dl, dr);
1433    }
1434
1435    #[test]
1436    fn test_assign_lower_overlapping() {
1437        let mut dl = Design::new();
1438        let c = dl.add_input("c", 1);
1439        let m = dl.add_match(MatchCell {
1440            value: c.clone(),
1441            enable: Net::ONE,
1442            patterns: vec![vec![Const::lit("0")], vec![Const::lit("1")]],
1443        });
1444        let a1 = dl.add_assign(assign(Value::zero(4), m[0], dl.add_input("x1", 4)));
1445        let a2 = dl.add_assign(assign(a1, m[1], dl.add_input("x2", 4)));
1446        let a3 = dl.add_assign(assign(a2, m[1], dl.add_input("x3", 4)));
1447        dl.add_output("y", a3);
1448        dl.apply();
1449
1450        decision(&mut dl);
1451
1452        let mut dr = Design::new();
1453        let c = dr.add_input1("c");
1454        let x1 = dr.add_input("x1", 4);
1455        let x2 = dr.add_input("x2", 4);
1456        let x3 = dr.add_input("x3", 4);
1457        let mc = dr.add_mux(c, Const::lit("10"), Const::lit("01"));
1458        let m2 = dr.add_mux(c, x2, x1);
1459        let m3 = dr.add_mux(mc[1], x3, m2);
1460        dr.add_output("y", m3);
1461
1462        assert_isomorphic!(dl, dr);
1463    }
1464
1465    #[test]
1466    fn test_assign_lower_different_matches() {
1467        let mut dl = Design::new();
1468        let c1 = dl.add_input("c1", 1);
1469        let c2 = dl.add_input("c2", 1);
1470        let m1 = dl.add_match(MatchCell { value: c1, enable: Net::ONE, patterns: vec![vec![Const::lit("0")]] });
1471        let m2 = dl.add_match(MatchCell { value: c2, enable: Net::ONE, patterns: vec![vec![Const::lit("0")]] });
1472        let a1 = dl.add_assign(assign(Value::zero(4), m1[0], dl.add_input("x1", 4)));
1473        let a2 = dl.add_assign(assign(a1, m2[0], dl.add_input("x2", 4)));
1474        dl.add_output("y", a2);
1475        dl.apply();
1476
1477        decision(&mut dl);
1478
1479        let mut dr = Design::new();
1480        let c1 = dr.add_input1("c1");
1481        let c2 = dr.add_input1("c2");
1482        let mc2 = dr.add_mux(c2, Const::lit("0"), Const::lit("1"));
1483        let m1 = dr.add_mux(c1, Value::zero(4), dr.add_input("x1", 4));
1484        let m2 = dr.add_mux(mc2[0], dr.add_input("x2", 4), m1);
1485        dr.add_output("y", m2);
1486
1487        assert_isomorphic!(dl, dr);
1488    }
1489
1490    #[test]
1491    fn test_assign_lower_partial() {
1492        let mut dl = Design::new();
1493        let en = dl.add_input1("en");
1494        let assign = dl.add_assign(AssignCell {
1495            value: dl.add_input("value", 6),
1496            enable: en,
1497            update: dl.add_input("update", 3),
1498            offset: 2,
1499        });
1500        dl.add_output("assign", assign);
1501        dl.apply();
1502
1503        decision(&mut dl);
1504
1505        let mut dr = Design::new();
1506        let en = dr.add_input1("en");
1507        let value = dr.add_input("value", 6);
1508        let update = dr.add_input("update", 3);
1509        let mux = dr.add_mux(en, update, value.slice(2..5));
1510        dr.add_output("assign", value.slice(..2).concat(mux.concat(value.slice(5..))));
1511
1512        assert_isomorphic!(dl, dr);
1513    }
1514
1515    #[test]
1516    fn test_match_eq_refinement() {
1517        let mut design = Design::new();
1518        let a = design.add_input("a", 2);
1519        let m = design.add_match(MatchCell {
1520            value: a,
1521            enable: Net::ONE,
1522            patterns: vec![vec![Const::lit("00")], vec![Const::lit("XX")]],
1523        });
1524        design.add_output("y", m);
1525        design.apply();
1526
1527        decision(&mut design);
1528        // the particular output generated here is uninteresting, assert that
1529        // lowering doesn't panic and is accepted by SMT
1530    }
1531}