prjunnamed_generic/
tree_rebalance.rs1use std::{
2 cell::RefCell,
3 collections::{BTreeSet, BinaryHeap, HashMap, HashSet},
4};
5
6use prjunnamed_netlist::{Cell, ControlNet, Design, Net, RewriteResult, RewriteRuleset};
7
8use crate::{LevelAnalysis, Normalize, SimpleAigOpt};
9
10struct TreeRebalance<'a> {
11 levels: &'a LevelAnalysis,
12 inner_aig: HashSet<Net>,
13 inner_xor: HashSet<Net>,
14 aig_trees: RefCell<HashMap<Net, BTreeSet<ControlNet>>>,
15 xor_trees: RefCell<HashMap<Net, BTreeSet<Net>>>,
16}
17
18impl<'a> TreeRebalance<'a> {
19 fn new(design: &Design, levels: &'a LevelAnalysis) -> Self {
20 let mut inner_aig = HashSet::new();
21 let mut inner_xor = HashSet::new();
22 let mut use_count = HashMap::<Net, u32>::new();
23 for cell in design.iter_cells() {
24 cell.visit(|net| {
25 *use_count.entry(net).or_default() += 1;
26 });
27 }
28 for cell in design.iter_cells() {
29 if let Cell::Aig(net1, net2) = *cell.get() {
30 for net in [net1, net2] {
31 if let ControlNet::Pos(net) = net
32 && use_count[&net] == 1
33 {
34 inner_aig.insert(net);
35 }
36 }
37 }
38 if let Cell::Xor(ref val1, ref val2) = *cell.get() {
39 for val in [val1, val2] {
40 for net in val {
41 if use_count[&net] == 1 {
42 inner_xor.insert(net);
43 }
44 }
45 }
46 }
47 }
48 Self { levels, inner_aig, inner_xor, aig_trees: Default::default(), xor_trees: Default::default() }
49 }
50}
51
52impl RewriteRuleset for TreeRebalance<'_> {
53 fn rewrite<'a>(
54 &self,
55 cell: &Cell,
56 _meta: prjunnamed_netlist::MetaItemRef<'a>,
57 output: Option<&prjunnamed_netlist::Value>,
58 rewriter: &prjunnamed_netlist::Rewriter<'a>,
59 ) -> RewriteResult<'a> {
60 let Some(output) = output else {
61 return RewriteResult::None;
62 };
63 if output.len() != 1 {
64 return RewriteResult::None;
65 }
66 let output = output[0];
67 match *cell {
68 Cell::Aig(net1, net2) => {
69 let mut aig_trees = self.aig_trees.borrow_mut();
70 let mut inputs1 = if let ControlNet::Pos(net) = net1
71 && let Some(inputs) = aig_trees.remove(&net)
72 {
73 inputs
74 } else {
75 BTreeSet::from_iter([net1])
76 };
77 let mut inputs2 = if let ControlNet::Pos(net) = net2
78 && let Some(inputs) = aig_trees.remove(&net)
79 {
80 inputs
81 } else {
82 BTreeSet::from_iter([net2])
83 };
84 if inputs1.len() < inputs2.len() {
85 std::mem::swap(&mut inputs1, &mut inputs2);
86 }
87 inputs1.extend(inputs2);
88 let inputs = inputs1;
89 if self.inner_aig.contains(&output) {
90 aig_trees.insert(output, inputs);
91 RewriteResult::None
92 } else {
93 if inputs.len() == 2 {
94 return RewriteResult::None;
95 }
96 let mut inputs = BinaryHeap::from_iter(
97 inputs.into_iter().map(|net| std::cmp::Reverse((self.levels.get(net.net()), net))),
98 );
99 while inputs.len() > 1 {
100 let (lvl1, net1) = inputs.pop().unwrap().0;
101 let (lvl2, net2) = inputs.pop().unwrap().0;
102 let lvl = lvl1.max(lvl2) + 1;
103 let val = rewriter.add_cell(Cell::Aig(net1, net2));
104 let net = ControlNet::Pos(val[0]);
105 inputs.push(std::cmp::Reverse((lvl, net)));
106 }
107 let net = inputs.pop().unwrap().0.1;
108 net.into()
109 }
110 }
111 Cell::Xor(ref val1, ref val2) => {
112 let net1 = val1[0];
113 let net2 = val2[0];
114 let mut xor_trees = self.xor_trees.borrow_mut();
115 let mut inputs1 =
116 if let Some(inputs) = xor_trees.remove(&net1) { inputs } else { BTreeSet::from_iter([net1]) };
117 let mut inputs2 =
118 if let Some(inputs) = xor_trees.remove(&net2) { inputs } else { BTreeSet::from_iter([net2]) };
119 if inputs1.len() < inputs2.len() {
120 std::mem::swap(&mut inputs1, &mut inputs2);
121 }
122 for net in inputs2 {
123 if !inputs1.remove(&net) {
124 inputs1.insert(net);
125 }
126 }
127 let inputs = inputs1;
128 if self.inner_xor.contains(&output) {
129 xor_trees.insert(output, inputs);
130 RewriteResult::None
131 } else {
132 if inputs.len() == 2 {
133 return RewriteResult::None;
134 }
135 let mut inputs = BinaryHeap::from_iter(
136 inputs.into_iter().map(|net| std::cmp::Reverse((self.levels.get(net), net))),
137 );
138 while inputs.len() > 1 {
139 let (lvl1, net1) = inputs.pop().unwrap().0;
140 let (lvl2, net2) = inputs.pop().unwrap().0;
141 let lvl = lvl1.max(lvl2) + 1;
142 let val = rewriter.add_cell(Cell::Xor(net1.into(), net2.into()));
143 inputs.push(std::cmp::Reverse((lvl, val[0])));
144 }
145 let net = inputs.pop().unwrap().0.1;
146 net.into()
147 }
148 }
149 _ => RewriteResult::None,
150 }
151 }
152
153 fn net_replaced(&self, _design: &Design, from: Net, to: Net) {
154 let mut aig_trees = self.aig_trees.borrow_mut();
155 if let Some(tree) = aig_trees.remove(&from) {
156 aig_trees.insert(to, tree);
157 }
158 let mut xor_trees = self.xor_trees.borrow_mut();
159 if let Some(tree) = xor_trees.remove(&from) {
160 xor_trees.insert(to, tree);
161 }
162 }
163}
164
165pub fn tree_rebalance(design: &mut Design) {
166 let levels = LevelAnalysis::new();
167 let rebalance = TreeRebalance::new(&design, &levels);
168 design.rewrite(&[&Normalize, &SimpleAigOpt, &levels, &rebalance]);
169}