1use std::{cell::RefCell, cmp::Ordering, collections::HashMap};
2
3use prjunnamed_netlist::{
4 Cell, ControlNet, Design, MetaItemRef, Net, RewriteNetSource, RewriteResult, RewriteRuleset, Rewriter, Value,
5};
6
7use crate::{LevelAnalysis, Normalize, SimpleAigOpt};
8
9#[derive(Clone, Debug)]
10struct AigChain {
11 invert: bool,
12 min_level: u32,
13 full_trees: Vec<AigFullTree>,
20}
21
22#[derive(Copy, Clone, Debug)]
23struct PropGen {
24 p: ControlNet,
25 g: ControlNet,
26}
27
28impl PropGen {
29 fn or(net: ControlNet) -> Self {
30 PropGen { p: ControlNet::Pos(Net::ONE), g: net }
31 }
32
33 fn and(net: ControlNet) -> Self {
34 PropGen { p: net, g: ControlNet::Pos(Net::ZERO) }
35 }
36
37 fn combine(rewriter: &Rewriter, a: PropGen, b: PropGen) -> PropGen {
38 let prop_val = rewriter.add_cell(Cell::Aig(a.p, b.p));
39 let tmp = rewriter.add_cell(Cell::Aig(a.g, b.p));
40 let genr_val_b = rewriter.add_cell(Cell::Aig(ControlNet::Neg(tmp[0]), !b.g));
41 PropGen { p: ControlNet::Pos(prop_val[0]), g: ControlNet::Neg(genr_val_b[0]) }
42 }
43}
44
45#[derive(Copy, Clone, Debug)]
46struct AigFullTree {
47 level: u32,
48 pg: PropGen,
49 cumulative: PropGen,
50}
51
52#[derive(Clone, Debug)]
53struct XorChain {
54 min_level: u32,
55 full_trees: Vec<XorFullTree>,
56}
57
58#[derive(Copy, Clone, Debug)]
59struct XorFullTree {
60 level: u32,
61 net: Net,
62 cumulative_net: Net,
63}
64
65pub struct ChainRebalance<'a> {
66 levels: &'a LevelAnalysis,
67 aig_chains: RefCell<HashMap<Net, AigChain>>,
68 xor_chains: RefCell<HashMap<Net, XorChain>>,
69}
70
71impl<'a> ChainRebalance<'a> {
72 pub fn new(levels: &'a LevelAnalysis) -> Self {
73 Self { levels, aig_chains: Default::default(), xor_chains: Default::default() }
74 }
75}
76
/// Chain rebalancing for single-bit `Aig` cells and one-bit `Xor` cells.
///
/// A cell whose operands sit at different levels (per `self.levels`) is treated
/// as one link of a chain. The deeper operand (`net_a`) continues the chain and
/// the shallower one (`net_b`) is a new leaf. The first link only records a
/// chain for its output. A later link looks the chain up by `net_a`, appends
/// `net_b`, and returns a rebuilt net for the cell.
///
/// `Cell::Aig(a, b)` is treated throughout as `a & b` on (possibly inverted)
/// control nets.
impl RewriteRuleset for ChainRebalance<'_> {
    /// Records or extends a chain for `cell`. When extending, it returns the
    /// rebuilt replacement for `cell`'s output.
    ///
    /// Returns `RewriteResult::None` (cell left as is) in these cases:
    /// * the output is missing or wider than one bit;
    /// * the cell is neither `Aig` nor a one-bit `Xor`;
    /// * both operands have the same level;
    /// * a chain is only being recorded.
    fn rewrite<'a>(
        &self,
        cell: &Cell,
        _meta: MetaItemRef<'a>,
        output: Option<&Value>,
        rewriter: &Rewriter<'a>,
    ) -> RewriteResult<'a> {
        let Some(output) = output else { return RewriteResult::None };
        if output.len() != 1 {
            return RewriteResult::None;
        }
        let output = output[0];
        match cell {
            &Cell::Aig(net1, net2) => {
                // Order the operands so that `net_a` is the deeper one. Equal
                // levels give no side to prefer, so no chain is formed.
                let level1 = self.levels.get(net1.net());
                let level2 = self.levels.get(net2.net());
                let (net_a, net_b, level_a, level_b) = match level1.cmp(&level2) {
                    Ordering::Less => (net2, net1, level2, level1),
                    Ordering::Equal => return RewriteResult::None,
                    Ordering::Greater => (net1, net2, level1, level2),
                };
                let mut aig_chains = self.aig_chains.borrow_mut();
                if let Some(chain) = aig_chains.get(&net_a.net()) {
                    // Extend a copy. The record for `net_a` itself is left untouched.
                    let mut chain = chain.clone();
                    // Consuming the chain's net inverted flips which polarity of
                    // the composed value the new output corresponds to.
                    if net_a.is_negative() {
                        chain.invert = !chain.invert;
                    }
                    chain.min_level = chain.min_level.max(level_b);
                    // Lift the topmost subtree up to `min_level`. While the
                    // subtree below it is no deeper than `min_level`, absorb it
                    // by taking that subtree's cumulative function.
                    let mut top = chain.full_trees.pop().unwrap();
                    while top.level < chain.min_level {
                        if let Some(&next_top) = chain.full_trees.last()
                            && next_top.level <= chain.min_level
                        {
                            top.level = next_top.level + 1;
                            top.pg = next_top.cumulative;
                            top.cumulative = next_top.cumulative;
                            chain.full_trees.pop();
                        } else {
                            top.level = chain.min_level;
                            break;
                        }
                    }
                    chain.full_trees.push(top);
                    // The new leaf is `x & net_b` for a non-inverted chain. When
                    // the chain's net holds the complement, it is `x | !net_b`.
                    let pg = if chain.invert { PropGen::or(!net_b) } else { PropGen::and(net_b) };
                    let mut new_top = AigFullTree { level: chain.min_level, pg, cumulative: pg };
                    // Merge with the topmost subtree while both sit at the same
                    // level. Each merge raises the level by one.
                    while let Some(&cur_top) = chain.full_trees.last()
                        && cur_top.level == new_top.level
                    {
                        chain.full_trees.pop();
                        new_top.pg = PropGen::combine(rewriter, cur_top.pg, new_top.pg);
                        new_top.cumulative = new_top.pg;
                        new_top.level += 1;
                    }
                    // Refresh every remaining subtree's cumulative function from
                    // the top down. Afterwards `cumulative` covers the whole chain.
                    let mut cumulative = new_top.pg;
                    for subtree in chain.full_trees.iter_mut().rev() {
                        cumulative = PropGen::combine(rewriter, subtree.pg, cumulative);
                        subtree.cumulative = cumulative;
                    }
                    chain.full_trees.push(new_top);
                    // `!p & !g` is the complement of the chain function at x = 1
                    // (`g | p`). Re-invert it unless the chain is inverted.
                    let mut result = rewriter.add_cell(Cell::Aig(!cumulative.p, !cumulative.g))[0];
                    if !chain.invert {
                        result = rewriter.add_cell(Cell::Not(result.into()))[0];
                    }
                    // If `result` is a `Not` cell, record the chain under its
                    // input with `invert` flipped. NOTE(review): presumably
                    // inverters feeding `Aig` cells end up as `ControlNet`
                    // polarity, so later lookups via `net_a.net()` see the
                    // un-inverted net. Confirm.
                    if let RewriteNetSource::Cell(cell, _, _) = rewriter.find_cell(result)
                        && let Cell::Not(ref inv_result) = *cell
                    {
                        let inv_result = inv_result[0];
                        chain.invert = !chain.invert;
                        aig_chains.insert(inv_result, chain);
                    } else {
                        aig_chains.insert(result, chain);
                    }
                    return result.into();
                } else {
                    // Start a new two-subtree chain recorded under `output`. The
                    // `net_a` subtree keeps `level_a`. The `net_b` subtree is
                    // placed at `level_a - 1`, which cannot underflow because
                    // `level_a > level_b`.
                    if net_a.is_negative() {
                        // With A = !net_a: output = !A & net_b. That is the
                        // complement of `!net_b | (A & x)` at x = 1, so the
                        // chain is inverted.
                        let chain = AigChain {
                            invert: true,
                            min_level: level_a - 1,
                            full_trees: vec![
                                AigFullTree {
                                    level: level_a,
                                    pg: PropGen::and(!net_a),
                                    cumulative: PropGen { p: !net_a, g: !net_b },
                                },
                                AigFullTree {
                                    level: level_a - 1,
                                    pg: PropGen::or(!net_b),
                                    cumulative: PropGen::or(!net_b),
                                },
                            ],
                        };
                        aig_chains.insert(output, chain);
                    } else {
                        // output = net_a & net_b, so the whole chain is
                        // `x & output`.
                        let chain = AigChain {
                            invert: false,
                            min_level: level_a - 1,
                            full_trees: vec![
                                AigFullTree {
                                    level: level_a,
                                    pg: PropGen::and(net_a),
                                    cumulative: PropGen::and(output.into()),
                                },
                                AigFullTree {
                                    level: level_a - 1,
                                    pg: PropGen::and(net_b),
                                    cumulative: PropGen::and(net_b),
                                },
                            ],
                        };
                        aig_chains.insert(output, chain);
                    }
                    RewriteResult::None
                }
            }
            Cell::Xor(val1, val2) if val1.len() == 1 => {
                // Order the operands so that `net_a` is the deeper one. Equal
                // levels form no chain.
                let net1 = val1[0];
                let net2 = val2[0];
                let level1 = self.levels.get(net1);
                let level2 = self.levels.get(net2);
                let (net_a, net_b, level_a, level_b) = match level1.cmp(&level2) {
                    Ordering::Less => (net2, net1, level2, level1),
                    Ordering::Equal => return RewriteResult::None,
                    Ordering::Greater => (net1, net2, level1, level2),
                };
                let mut xor_chains = self.xor_chains.borrow_mut();
                if let Some(chain) = xor_chains.get(&net_a) {
                    let mut chain = chain.clone();
                    chain.min_level = chain.min_level.max(level_b);
                    // A single-subtree chain is not rebuilt, since that would
                    // only recreate an XOR of `net_b` with that subtree's net.
                    // Instead, `output` is recorded as a two-subtree chain (when
                    // the existing subtree is deeper than `net_b`) and the cell
                    // is kept.
                    if chain.full_trees.len() == 1 {
                        if chain.full_trees[0].level > level_b {
                            chain.full_trees[0].cumulative_net = output;
                            chain.full_trees.push(XorFullTree { level: level_b, net: net_b, cumulative_net: net_b });
                            xor_chains.insert(output, chain);
                        }
                        return RewriteResult::None;
                    }
                    // Lift the topmost subtree up to `min_level`, absorbing the
                    // subtrees below it while they are no deeper than `min_level`.
                    let mut top = chain.full_trees.pop().unwrap();
                    while top.level < chain.min_level {
                        if let Some(&next_top) = chain.full_trees.last()
                            && next_top.level <= chain.min_level
                        {
                            top.level = next_top.level + 1;
                            top.net = next_top.cumulative_net;
                            top.cumulative_net = next_top.cumulative_net;
                            chain.full_trees.pop();
                        } else {
                            top.level = chain.min_level;
                            break;
                        }
                    }
                    chain.full_trees.push(top);
                    // Add `net_b` at `min_level`. Merge it with the topmost
                    // subtree while both sit at the same level.
                    let mut level_top = chain.min_level;
                    let mut net_top = net_b;
                    while let Some(&next_top) = chain.full_trees.last()
                        && next_top.level == level_top
                    {
                        chain.full_trees.pop();
                        let val = rewriter.add_cell(Cell::Xor(net_top.into(), next_top.net.into()));
                        net_top = val[0];
                        level_top += 1;
                    }
                    // XOR the new top with each remaining subtree, top down,
                    // refreshing cumulative nets. The final net is the chain's
                    // new value.
                    let mut cumulative_net = net_top;
                    for subtree in chain.full_trees.iter_mut().rev() {
                        let val = rewriter.add_cell(Cell::Xor(cumulative_net.into(), subtree.net.into()));
                        cumulative_net = val[0];
                        subtree.cumulative_net = cumulative_net;
                    }
                    chain.full_trees.push(XorFullTree { level: level_top, net: net_top, cumulative_net: net_top });
                    xor_chains.insert(cumulative_net, chain);
                    return cumulative_net.into();
                } else {
                    // Start a new two-subtree chain recorded under `output`.
                    // `level_a - 1` cannot underflow because `level_a > level_b`.
                    let chain = XorChain {
                        min_level: level_a - 1,
                        full_trees: vec![
                            XorFullTree { level: level_a, net: net_a, cumulative_net: output },
                            XorFullTree { level: level_a - 1, net: net_b, cumulative_net: net_b },
                        ],
                    };
                    xor_chains.insert(output, chain);
                    RewriteResult::None
                }
            }
            _ => RewriteResult::None,
        }
    }

    /// Keeps chain records reachable when net `from` is replaced by `to`.
    ///
    /// A chain recorded under `from` is copied to `to`, unless `to` already has
    /// a chain of the same kind. The entry for `from` is kept.
    fn net_replaced(&self, _design: &Design, from: Net, to: Net) {
        let mut aig_chains = self.aig_chains.borrow_mut();
        if let Some(chain) = aig_chains.get(&from)
            && !aig_chains.contains_key(&to)
        {
            let chain = chain.clone();
            aig_chains.insert(to, chain);
        }
        let mut xor_chains = self.xor_chains.borrow_mut();
        if let Some(chain) = xor_chains.get(&from)
            && !xor_chains.contains_key(&to)
        {
            let chain = chain.clone();
            xor_chains.insert(to, chain);
        }
    }
}
289
290pub fn chain_rebalance(design: &mut Design) {
291 let levels = LevelAnalysis::new();
292 let rebalance = ChainRebalance::new(&levels);
293 design.rewrite(&[&Normalize, &SimpleAigOpt, &levels, &rebalance]);
294}