prjunnamed_generic/
chain_rebalance.rs

1use std::{cell::RefCell, cmp::Ordering, collections::HashMap};
2
3use prjunnamed_netlist::{
4    Cell, ControlNet, Design, MetaItemRef, Net, RewriteNetSource, RewriteResult, RewriteRuleset, Rewriter, Value,
5};
6
7use crate::{LevelAnalysis, Normalize, SimpleAigOpt};
8
9#[derive(Clone, Debug)]
10struct AigChain {
11    invert: bool,
12    min_level: u32,
13    /// List of (level, propagate, generate) pairs to be used for further rebalancing.
14    ///
15    /// This list satisfies the following conditions:
16    /// 1. The node is equivalent to an AND-OR of all inputs on this list (in order, starting from const-1).
17    /// 2. The list is sorted strictly descending by level (no two nodes are the same level).
18    /// 3. All prop/genr levels are no smaller than `min_level`.
19    full_trees: Vec<AigFullTree>,
20}
21
22#[derive(Copy, Clone, Debug)]
23struct PropGen {
24    p: ControlNet,
25    g: ControlNet,
26}
27
28impl PropGen {
29    fn or(net: ControlNet) -> Self {
30        PropGen { p: ControlNet::Pos(Net::ONE), g: net }
31    }
32
33    fn and(net: ControlNet) -> Self {
34        PropGen { p: net, g: ControlNet::Pos(Net::ZERO) }
35    }
36
37    fn combine(rewriter: &Rewriter, a: PropGen, b: PropGen) -> PropGen {
38        let prop_val = rewriter.add_cell(Cell::Aig(a.p, b.p));
39        let tmp = rewriter.add_cell(Cell::Aig(a.g, b.p));
40        let genr_val_b = rewriter.add_cell(Cell::Aig(ControlNet::Neg(tmp[0]), !b.g));
41        PropGen { p: ControlNet::Pos(prop_val[0]), g: ControlNet::Neg(genr_val_b[0]) }
42    }
43}
44
45#[derive(Copy, Clone, Debug)]
46struct AigFullTree {
47    level: u32,
48    pg: PropGen,
49    cumulative: PropGen,
50}
51
52#[derive(Clone, Debug)]
53struct XorChain {
54    min_level: u32,
55    full_trees: Vec<XorFullTree>,
56}
57
58#[derive(Copy, Clone, Debug)]
59struct XorFullTree {
60    level: u32,
61    net: Net,
62    cumulative_net: Net,
63}
64
65pub struct ChainRebalance<'a> {
66    levels: &'a LevelAnalysis,
67    aig_chains: RefCell<HashMap<Net, AigChain>>,
68    xor_chains: RefCell<HashMap<Net, XorChain>>,
69}
70
71impl<'a> ChainRebalance<'a> {
72    pub fn new(levels: &'a LevelAnalysis) -> Self {
73        Self { levels, aig_chains: Default::default(), xor_chains: Default::default() }
74    }
75}
76
77impl RewriteRuleset for ChainRebalance<'_> {
78    fn rewrite<'a>(
79        &self,
80        cell: &Cell,
81        _meta: MetaItemRef<'a>,
82        output: Option<&Value>,
83        rewriter: &Rewriter<'a>,
84    ) -> RewriteResult<'a> {
85        let Some(output) = output else { return RewriteResult::None };
86        if output.len() != 1 {
87            return RewriteResult::None;
88        }
89        let output = output[0];
90        match cell {
91            &Cell::Aig(net1, net2) => {
92                let level1 = self.levels.get(net1.net());
93                let level2 = self.levels.get(net2.net());
94                let (net_a, net_b, level_a, level_b) = match level1.cmp(&level2) {
95                    Ordering::Less => (net2, net1, level2, level1),
96                    Ordering::Equal => return RewriteResult::None,
97                    Ordering::Greater => (net1, net2, level1, level2),
98                };
99                let mut aig_chains = self.aig_chains.borrow_mut();
100                if let Some(chain) = aig_chains.get(&net_a.net()) {
101                    let mut chain = chain.clone();
102                    if net_a.is_negative() {
103                        chain.invert = !chain.invert;
104                    }
105                    chain.min_level = chain.min_level.max(level_b);
106                    // adjust levels of everything to at least the new min_level
107                    let mut top = chain.full_trees.pop().unwrap();
108                    while top.level < chain.min_level {
109                        if let Some(&next_top) = chain.full_trees.last()
110                            && next_top.level <= chain.min_level
111                        {
112                            top.level = next_top.level + 1;
113                            top.pg = next_top.cumulative;
114                            top.cumulative = next_top.cumulative;
115                            chain.full_trees.pop();
116                        } else {
117                            top.level = chain.min_level;
118                            break;
119                        }
120                    }
121                    chain.full_trees.push(top);
122                    // add the new input; merge last two entries until invariant holds
123                    let pg = if chain.invert { PropGen::or(!net_b) } else { PropGen::and(net_b) };
124                    let mut new_top = AigFullTree { level: chain.min_level, pg, cumulative: pg };
125                    while let Some(&cur_top) = chain.full_trees.last()
126                        && cur_top.level == new_top.level
127                    {
128                        chain.full_trees.pop();
129                        new_top.pg = PropGen::combine(rewriter, cur_top.pg, new_top.pg);
130                        new_top.cumulative = new_top.pg;
131                        new_top.level += 1;
132                    }
133                    // don't push the new last entry just yet; compute cumulative first
134                    let mut cumulative = new_top.pg;
135                    for subtree in chain.full_trees.iter_mut().rev() {
136                        cumulative = PropGen::combine(rewriter, subtree.pg, cumulative);
137                        subtree.cumulative = cumulative;
138                    }
139                    // now push the new last entry
140                    chain.full_trees.push(new_top);
141                    let mut result = rewriter.add_cell(Cell::Aig(!cumulative.p, !cumulative.g))[0];
142                    if !chain.invert {
143                        result = rewriter.add_cell(Cell::Not(result.into()))[0];
144                    }
145                    if let RewriteNetSource::Cell(cell, _, _) = rewriter.find_cell(result)
146                        && let Cell::Not(ref inv_result) = *cell
147                    {
148                        let inv_result = inv_result[0];
149                        chain.invert = !chain.invert;
150                        aig_chains.insert(inv_result, chain);
151                    } else {
152                        aig_chains.insert(result, chain);
153                    }
154                    return result.into();
155                } else {
156                    if net_a.is_negative() {
157                        let chain = AigChain {
158                            invert: true,
159                            min_level: level_a - 1,
160                            full_trees: vec![
161                                AigFullTree {
162                                    level: level_a,
163                                    pg: PropGen::and(!net_a),
164                                    cumulative: PropGen { p: !net_a, g: !net_b },
165                                },
166                                AigFullTree {
167                                    level: level_a - 1,
168                                    pg: PropGen::or(!net_b),
169                                    cumulative: PropGen::or(!net_b),
170                                },
171                            ],
172                        };
173                        aig_chains.insert(output, chain);
174                    } else {
175                        let chain = AigChain {
176                            invert: false,
177                            min_level: level_a - 1,
178                            full_trees: vec![
179                                AigFullTree {
180                                    level: level_a,
181                                    pg: PropGen::and(net_a),
182                                    cumulative: PropGen::and(output.into()),
183                                },
184                                AigFullTree {
185                                    level: level_a - 1,
186                                    pg: PropGen::and(net_b),
187                                    cumulative: PropGen::and(net_b),
188                                },
189                            ],
190                        };
191                        aig_chains.insert(output, chain);
192                    }
193                    RewriteResult::None
194                }
195            }
196            Cell::Xor(val1, val2) if val1.len() == 1 => {
197                let net1 = val1[0];
198                let net2 = val2[0];
199                let level1 = self.levels.get(net1);
200                let level2 = self.levels.get(net2);
201                let (net_a, net_b, level_a, level_b) = match level1.cmp(&level2) {
202                    Ordering::Less => (net2, net1, level2, level1),
203                    Ordering::Equal => return RewriteResult::None,
204                    Ordering::Greater => (net1, net2, level1, level2),
205                };
206                let mut xor_chains = self.xor_chains.borrow_mut();
207                if let Some(chain) = xor_chains.get(&net_a) {
208                    let mut chain = chain.clone();
209                    chain.min_level = chain.min_level.max(level_b);
210                    if chain.full_trees.len() == 1 {
211                        if chain.full_trees[0].level > level_b {
212                            chain.full_trees[0].cumulative_net = output;
213                            chain.full_trees.push(XorFullTree { level: level_b, net: net_b, cumulative_net: net_b });
214                            xor_chains.insert(output, chain);
215                        }
216                        return RewriteResult::None;
217                    }
218                    // adjust levels of everything to at least the new min_level
219                    let mut top = chain.full_trees.pop().unwrap();
220                    while top.level < chain.min_level {
221                        if let Some(&next_top) = chain.full_trees.last()
222                            && next_top.level <= chain.min_level
223                        {
224                            top.level = next_top.level + 1;
225                            top.net = next_top.cumulative_net;
226                            top.cumulative_net = next_top.cumulative_net;
227                            chain.full_trees.pop();
228                        } else {
229                            top.level = chain.min_level;
230                            break;
231                        }
232                    }
233                    chain.full_trees.push(top);
234                    // add the new input; merge last two entries until invariant holds
235                    let mut level_top = chain.min_level;
236                    let mut net_top = net_b;
237                    while let Some(&next_top) = chain.full_trees.last()
238                        && next_top.level == level_top
239                    {
240                        chain.full_trees.pop();
241                        let val = rewriter.add_cell(Cell::Xor(net_top.into(), next_top.net.into()));
242                        net_top = val[0];
243                        level_top += 1;
244                    }
245                    // don't push the new last entry just yet; compute cumulative_net first
246                    let mut cumulative_net = net_top;
247                    for subtree in chain.full_trees.iter_mut().rev() {
248                        let val = rewriter.add_cell(Cell::Xor(cumulative_net.into(), subtree.net.into()));
249                        cumulative_net = val[0];
250                        subtree.cumulative_net = cumulative_net;
251                    }
252                    // now push the new last entry
253                    chain.full_trees.push(XorFullTree { level: level_top, net: net_top, cumulative_net: net_top });
254                    xor_chains.insert(cumulative_net, chain);
255                    return cumulative_net.into();
256                } else {
257                    let chain = XorChain {
258                        min_level: level_a - 1,
259                        full_trees: vec![
260                            XorFullTree { level: level_a, net: net_a, cumulative_net: output },
261                            XorFullTree { level: level_a - 1, net: net_b, cumulative_net: net_b },
262                        ],
263                    };
264                    xor_chains.insert(output, chain);
265                    RewriteResult::None
266                }
267            }
268            _ => RewriteResult::None,
269        }
270    }
271
272    fn net_replaced(&self, _design: &Design, from: Net, to: Net) {
273        let mut aig_chains = self.aig_chains.borrow_mut();
274        if let Some(chain) = aig_chains.get(&from)
275            && !aig_chains.contains_key(&to)
276        {
277            let chain = chain.clone();
278            aig_chains.insert(to, chain);
279        }
280        let mut xor_chains = self.xor_chains.borrow_mut();
281        if let Some(chain) = xor_chains.get(&from)
282            && !xor_chains.contains_key(&to)
283        {
284            let chain = chain.clone();
285            xor_chains.insert(to, chain);
286        }
287    }
288}
289
290pub fn chain_rebalance(design: &mut Design) {
291    let levels = LevelAnalysis::new();
292    let rebalance = ChainRebalance::new(&levels);
293    design.rewrite(&[&Normalize, &SimpleAigOpt, &levels, &rebalance]);
294}