1use std::{borrow::Cow, collections::HashMap};
2
3use prjunnamed_netlist::{Cell, Design, Value};
4
5struct Numberer(HashMap<Cell, Value>);
6
7impl Numberer {
8 fn new() -> Self {
9 Numberer(HashMap::new())
10 }
11
12 fn find_or_insert<'a>(&mut self, cell: Cell, value: impl Into<Cow<'a, Value>>) -> Value {
13 self.0.entry(cell).or_insert_with(|| value.into().into_owned()).clone()
14 }
15
16 fn bitwise_unary<F>(&mut self, rebuild: F, arg: Value, out: &Value) -> Value
17 where
18 F: Fn(Value) -> Cell,
19 {
20 let mut result = Value::new();
21 for (out_net, arg_net) in out.iter().zip(arg.iter()) {
22 let bit_cell = rebuild(Value::from(arg_net));
23 result.extend(self.find_or_insert(bit_cell, out_net));
24 }
25 result
26 }
27
28 fn commutative_bitwise_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
29 where
30 F: Fn(Value, Value) -> Cell,
31 {
32 let mut result = Value::new();
33 for (out_net, (arg1_net, arg2_net)) in out.iter().zip(arg1.iter().zip(arg2.iter())) {
34 let (arg1_net, arg2_net) = if arg1_net <= arg2_net { (arg1_net, arg2_net) } else { (arg2_net, arg1_net) };
35 let bit_cell = rebuild(Value::from(arg1_net), Value::from(arg2_net));
36 result.extend(self.find_or_insert(bit_cell, out_net));
37 }
38 result
39 }
40
41 fn bitwise_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
42 where
43 F: Fn(Value, Value) -> Cell,
44 {
45 let mut result = Value::new();
46 for (out_net, (arg1_net, arg2_net)) in out.iter().zip(arg1.iter().zip(arg2.iter())) {
47 let bit_cell = rebuild(Value::from(arg1_net), Value::from(arg2_net));
48 result.extend(self.find_or_insert(bit_cell, out_net));
49 }
50 result
51 }
52
53 fn commutative_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
54 where
55 F: Fn(Value, Value) -> Cell,
56 {
57 let (arg1, arg2) = if arg1 <= arg2 { (arg1, arg2) } else { (arg2, arg1) };
58 let cell = rebuild(arg1, arg2);
59 self.find_or_insert(cell, out)
60 }
61}
62
63pub fn merge(design: &mut Design) -> bool {
64 let mut numberer = Numberer::new();
65 for cell_ref in design.iter_cells_topo().filter(|cell_ref| !cell_ref.get().has_effects(design)) {
66 let mut cell = cell_ref.get().into_owned();
67 cell.visit_mut(|net| *net = design.map_net(*net));
68 let output = cell_ref.output();
69 let canon = match cell {
70 Cell::Buf(arg) => numberer.bitwise_unary(Cell::Buf, arg, &output),
71 Cell::Not(arg) => numberer.bitwise_unary(Cell::Not, arg, &output),
72 Cell::And(arg1, arg2) => numberer.commutative_bitwise_binary(Cell::And, arg1, arg2, &output),
73 Cell::Or(arg1, arg2) => numberer.commutative_bitwise_binary(Cell::Or, arg1, arg2, &output),
74 Cell::Xor(arg1, arg2) => numberer.commutative_bitwise_binary(Cell::Xor, arg1, arg2, &output),
75 Cell::Mux(arg1, arg2, arg3) => {
76 numberer.bitwise_binary(|arg2, arg3| Cell::Mux(arg1, arg2, arg3), arg2, arg3, &output)
77 }
78 Cell::Adc(arg1, arg2, arg3) => {
79 numberer.commutative_binary(|arg1, arg2| Cell::Adc(arg1, arg2, arg3), arg1, arg2, &output)
80 }
81 Cell::Eq(arg1, arg2) => numberer.commutative_binary(Cell::Eq, arg1, arg2, &output),
82 Cell::Mul(arg1, arg2) => numberer.commutative_binary(Cell::Mul, arg1, arg2, &output),
83 Cell::Aig(arg1, arg2) => {
84 let (arg1, arg2) = if arg1 <= arg2 { (arg1, arg2) } else { (arg2, arg1) };
85 let cell = Cell::Aig(arg1, arg2);
86 numberer.find_or_insert(cell, &output)
87 }
88 _ => numberer.find_or_insert(cell, &output),
89 };
90 if cfg!(feature = "trace") {
91 if output != canon {
92 eprintln!(">merge {} => {}", design.display_value(&output), design.display_value(&canon));
93 }
94 }
95 for canon_net in canon.iter() {
96 let Ok((canon_cell_ref, _offset)) = design.find_cell(canon_net) else { unreachable!() };
97 canon_cell_ref.append_metadata(cell_ref.metadata());
98 }
99 design.replace_value(output, canon);
100 }
101 design.compact()
102}
103
#[cfg(test)]
mod test {
    use prjunnamed_netlist::{assert_isomorphic, Design, Value};

    use crate::merge::merge;

    /// A two-bit `xor(b, a)` is expected to end up sharing the one-bit
    /// `xor(a[i], b[i])` cells even though its operands are swapped.
    #[test]
    fn test_merge_commutative_xor() {
        // Design under test: two one-bit xors plus one swapped two-bit xor.
        let mut actual = Design::new();
        let a = actual.add_input("a", 2);
        let b = actual.add_input("b", 2);
        let bit0 = actual.add_xor1(a[0], b[0]);
        let bit1 = actual.add_xor1(a[1], b[1]);
        let swapped = actual.add_xor(b, a);
        let y = Value::from(bit0).concat(bit1).concat(swapped);
        actual.add_output("y", y);
        actual.apply();
        merge(&mut actual);

        // Reference design: the one-bit xors are reused for the upper two bits.
        let mut expected = Design::new();
        let a = expected.add_input("a", 2);
        let b = expected.add_input("b", 2);
        let bit0 = expected.add_xor1(a[0], b[0]);
        let bit1 = expected.add_xor1(a[1], b[1]);
        let y = Value::from(bit0).concat(bit1).concat(bit0).concat(bit1);
        expected.add_output("y", y);
        expected.apply();

        assert_isomorphic!(actual, expected);
    }
}