prjunnamed_generic/
merge.rs

1use std::{borrow::Cow, collections::HashMap};
2
3use prjunnamed_netlist::{Cell, Design, Value};
4
5struct Numberer(HashMap<Cell, Value>);
6
7impl Numberer {
8    fn new() -> Self {
9        Numberer(HashMap::new())
10    }
11
12    fn find_or_insert<'a>(&mut self, cell: Cell, value: impl Into<Cow<'a, Value>>) -> Value {
13        self.0.entry(cell).or_insert_with(|| value.into().into_owned()).clone()
14    }
15
16    fn bitwise_unary<F>(&mut self, rebuild: F, arg: Value, out: &Value) -> Value
17    where
18        F: Fn(Value) -> Cell,
19    {
20        let mut result = Value::new();
21        for (out_net, arg_net) in out.iter().zip(arg.iter()) {
22            let bit_cell = rebuild(Value::from(arg_net));
23            result.extend(self.find_or_insert(bit_cell, out_net));
24        }
25        result
26    }
27
28    fn commutative_bitwise_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
29    where
30        F: Fn(Value, Value) -> Cell,
31    {
32        let mut result = Value::new();
33        for (out_net, (arg1_net, arg2_net)) in out.iter().zip(arg1.iter().zip(arg2.iter())) {
34            let (arg1_net, arg2_net) = if arg1_net <= arg2_net { (arg1_net, arg2_net) } else { (arg2_net, arg1_net) };
35            let bit_cell = rebuild(Value::from(arg1_net), Value::from(arg2_net));
36            result.extend(self.find_or_insert(bit_cell, out_net));
37        }
38        result
39    }
40
41    fn bitwise_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
42    where
43        F: Fn(Value, Value) -> Cell,
44    {
45        let mut result = Value::new();
46        for (out_net, (arg1_net, arg2_net)) in out.iter().zip(arg1.iter().zip(arg2.iter())) {
47            let bit_cell = rebuild(Value::from(arg1_net), Value::from(arg2_net));
48            result.extend(self.find_or_insert(bit_cell, out_net));
49        }
50        result
51    }
52
53    fn commutative_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
54    where
55        F: Fn(Value, Value) -> Cell,
56    {
57        let (arg1, arg2) = if arg1 <= arg2 { (arg1, arg2) } else { (arg2, arg1) };
58        let cell = rebuild(arg1, arg2);
59        self.find_or_insert(cell, out)
60    }
61}
62
63pub fn merge(design: &mut Design) -> bool {
64    let mut numberer = Numberer::new();
65    for cell_ref in design.iter_cells_topo().filter(|cell_ref| !cell_ref.get().has_effects(design)) {
66        let mut cell = cell_ref.get().into_owned();
67        cell.visit_mut(|net| *net = design.map_net(*net));
68        let output = cell_ref.output();
69        let canon = match cell {
70            Cell::Buf(arg) => numberer.bitwise_unary(Cell::Buf, arg, &output),
71            Cell::Not(arg) => numberer.bitwise_unary(Cell::Not, arg, &output),
72            Cell::And(arg1, arg2) => numberer.commutative_bitwise_binary(Cell::And, arg1, arg2, &output),
73            Cell::Or(arg1, arg2) => numberer.commutative_bitwise_binary(Cell::Or, arg1, arg2, &output),
74            Cell::Xor(arg1, arg2) => numberer.commutative_bitwise_binary(Cell::Xor, arg1, arg2, &output),
75            Cell::Mux(arg1, arg2, arg3) => {
76                numberer.bitwise_binary(|arg2, arg3| Cell::Mux(arg1, arg2, arg3), arg2, arg3, &output)
77            }
78            Cell::Adc(arg1, arg2, arg3) => {
79                numberer.commutative_binary(|arg1, arg2| Cell::Adc(arg1, arg2, arg3), arg1, arg2, &output)
80            }
81            Cell::Eq(arg1, arg2) => numberer.commutative_binary(Cell::Eq, arg1, arg2, &output),
82            Cell::Mul(arg1, arg2) => numberer.commutative_binary(Cell::Mul, arg1, arg2, &output),
83            Cell::Aig(arg1, arg2) => {
84                let (arg1, arg2) = if arg1 <= arg2 { (arg1, arg2) } else { (arg2, arg1) };
85                let cell = Cell::Aig(arg1, arg2);
86                numberer.find_or_insert(cell, &output)
87            }
88            _ => numberer.find_or_insert(cell, &output),
89        };
90        if cfg!(feature = "trace") {
91            if output != canon {
92                eprintln!(">merge {} => {}", design.display_value(&output), design.display_value(&canon));
93            }
94        }
95        for canon_net in canon.iter() {
96            let Ok((canon_cell_ref, _offset)) = design.find_cell(canon_net) else { unreachable!() };
97            canon_cell_ref.append_metadata(cell_ref.metadata());
98        }
99        design.replace_value(output, canon);
100    }
101    design.compact()
102}
103
#[cfg(test)]
mod test {
    use prjunnamed_netlist::{assert_isomorphic, Design, Value};

    use crate::merge::merge;

    /// A 2-bit XOR with swapped operands must collapse onto two existing 1-bit XORs.
    #[test]
    fn test_merge_commutative_xor() {
        // Input design: `bit0`/`bit1` compute a[i] ^ b[i] separately, while `wide` computes b ^ a
        // over both bits at once.
        let mut actual = Design::new();
        let a = actual.add_input("a", 2);
        let b = actual.add_input("b", 2);
        let bit0 = actual.add_xor1(a[0], b[0]);
        let bit1 = actual.add_xor1(a[1], b[1]);
        let wide = actual.add_xor(b, a);
        actual.add_output("y", Value::from(bit0).concat(bit1).concat(wide));
        actual.apply();
        merge(&mut actual);

        // Expected design: the wide XOR is gone and its bits are driven by `bit0`/`bit1`.
        let mut expected = Design::new();
        let a = expected.add_input("a", 2);
        let b = expected.add_input("b", 2);
        let bit0 = expected.add_xor1(a[0], b[0]);
        let bit1 = expected.add_xor1(a[1], b[1]);
        expected.add_output("y", Value::from(bit0).concat(bit1).concat(bit0).concat(bit1));
        expected.apply();

        assert_isomorphic!(actual, expected);
    }
}