prjunnamed_generic/
merge.rs

1use std::{borrow::Cow, collections::HashMap};
2
3use prjunnamed_netlist::{Cell, Design, Value};
4
5struct Numberer(HashMap<Cell, Value>);
6
7impl Numberer {
8    fn new() -> Self {
9        Numberer(HashMap::new())
10    }
11
12    fn find_or_insert<'a>(&mut self, cell: Cell, value: impl Into<Cow<'a, Value>>) -> Value {
13        self.0.entry(cell).or_insert_with(|| value.into().into_owned()).clone()
14    }
15
16    fn bitwise_unary<F>(&mut self, rebuild: F, arg: Value, out: &Value) -> Value
17    where
18        F: Fn(Value) -> Cell,
19    {
20        let mut result = Value::new();
21        for (out_net, arg_net) in out.iter().zip(arg.iter()) {
22            let bit_cell = rebuild(Value::from(arg_net));
23            result.extend(self.find_or_insert(bit_cell, out_net));
24        }
25        result
26    }
27
28    fn commutative_bitwise_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
29    where
30        F: Fn(Value, Value) -> Cell,
31    {
32        let mut result = Value::new();
33        for (out_net, (arg1_net, arg2_net)) in out.iter().zip(arg1.iter().zip(arg2.iter())) {
34            let (arg1_net, arg2_net) = if arg1_net <= arg2_net { (arg1_net, arg2_net) } else { (arg2_net, arg1_net) };
35            let bit_cell = rebuild(Value::from(arg1_net), Value::from(arg2_net));
36            result.extend(self.find_or_insert(bit_cell, out_net));
37        }
38        result
39    }
40
41    fn bitwise_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
42    where
43        F: Fn(Value, Value) -> Cell,
44    {
45        let mut result = Value::new();
46        for (out_net, (arg1_net, arg2_net)) in out.iter().zip(arg1.iter().zip(arg2.iter())) {
47            let bit_cell = rebuild(Value::from(arg1_net), Value::from(arg2_net));
48            result.extend(self.find_or_insert(bit_cell, out_net));
49        }
50        result
51    }
52
53    fn commutative_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
54    where
55        F: Fn(Value, Value) -> Cell,
56    {
57        let (arg1, arg2) = if arg1 <= arg2 { (arg1, arg2) } else { (arg2, arg1) };
58        let cell = rebuild(arg1, arg2);
59        self.find_or_insert(cell, out)
60    }
61}
62
63pub fn merge(design: &mut Design) -> bool {
64    let mut numberer = Numberer::new();
65    for cell_ref in design.iter_cells_topo().filter(|cell_ref| !cell_ref.get().has_effects(design)) {
66        let mut cell = cell_ref.get().into_owned();
67        cell.visit_mut(|net| *net = design.map_net(*net));
68        let output = cell_ref.output();
69        let canon = match cell {
70            Cell::Buf(arg) => numberer.bitwise_unary(Cell::Buf, arg, &output),
71            Cell::Not(arg) => numberer.bitwise_unary(Cell::Not, arg, &output),
72            Cell::And(arg1, arg2) => numberer.commutative_bitwise_binary(Cell::And, arg1, arg2, &output),
73            Cell::Or(arg1, arg2) => numberer.commutative_bitwise_binary(Cell::Or, arg1, arg2, &output),
74            Cell::Xor(arg1, arg2) => numberer.commutative_bitwise_binary(Cell::Xor, arg1, arg2, &output),
75            Cell::Mux(arg1, arg2, arg3) => {
76                numberer.bitwise_binary(|arg2, arg3| Cell::Mux(arg1, arg2, arg3), arg2, arg3, &output)
77            }
78            Cell::Adc(arg1, arg2, arg3) => {
79                numberer.commutative_binary(|arg1, arg2| Cell::Adc(arg1, arg2, arg3), arg1, arg2, &output)
80            }
81            Cell::Eq(arg1, arg2) => numberer.commutative_binary(Cell::Eq, arg1, arg2, &output),
82            Cell::Mul(arg1, arg2) => numberer.commutative_binary(Cell::Mul, arg1, arg2, &output),
83            Cell::Aig(arg1, arg2) => {
84                let (arg1, arg2) = if arg1 <= arg2 { (arg1, arg2) } else { (arg2, arg1) };
85                let cell = Cell::Aig(arg1, arg2);
86                numberer.find_or_insert(cell, &output)
87            }
88            _ => numberer.find_or_insert(cell, &output),
89        };
90        if cfg!(feature = "trace") && output != canon {
91            eprintln!(">merge {} => {}", design.display_value(&output), design.display_value(&canon));
92        }
93        for canon_net in canon.iter() {
94            let Ok((canon_cell_ref, _offset)) = design.find_cell(canon_net) else { unreachable!() };
95            canon_cell_ref.append_metadata(cell_ref.metadata());
96        }
97        design.replace_value(output, canon);
98    }
99    design.compact()
100}
101
#[cfg(test)]
mod test {
    use prjunnamed_netlist::{assert_isomorphic, Design, Value};

    use crate::merge::merge;

    /// Builds two 1-bit XORs and then a 2-bit XOR of the same inputs with its
    /// operands swapped. After `merge`, the 2-bit XOR's bits must be served by
    /// the two 1-bit XORs, so operand order does not block merging.
    #[test]
    fn test_merge_commutative_xor() {
        // Design under test.
        let mut actual = Design::new();
        let a = actual.add_input("a", 2);
        let b = actual.add_input("b", 2);
        let xor_lo = actual.add_xor1(a[0], b[0]);
        let xor_hi = actual.add_xor1(a[1], b[1]);
        let xor_swapped = actual.add_xor(b, a);
        actual.add_output("y", Value::from(xor_lo).concat(xor_hi).concat(xor_swapped));
        actual.apply();
        merge(&mut actual);

        // Expected shape: the swapped 2-bit XOR collapses onto the 1-bit XORs.
        let mut expected = Design::new();
        let a = expected.add_input("a", 2);
        let b = expected.add_input("b", 2);
        let xor_lo = expected.add_xor1(a[0], b[0]);
        let xor_hi = expected.add_xor1(a[1], b[1]);
        expected.add_output("y", Value::from(xor_lo).concat(xor_hi).concat(xor_lo).concat(xor_hi));
        expected.apply();

        assert_isomorphic!(actual, expected);
    }
}