prjunnamed_generic/
tree_rebalance.rs

1use std::{
2    cell::RefCell,
3    collections::{BTreeSet, BinaryHeap, HashMap, HashSet},
4};
5
6use prjunnamed_netlist::{Cell, ControlNet, Design, Net, RewriteResult, RewriteRuleset};
7
8use crate::{LevelAnalysis, Normalize, SimpleAigOpt};
9
10struct TreeRebalance<'a> {
11    levels: &'a LevelAnalysis,
12    inner_aig: HashSet<Net>,
13    inner_xor: HashSet<Net>,
14    aig_trees: RefCell<HashMap<Net, BTreeSet<ControlNet>>>,
15    xor_trees: RefCell<HashMap<Net, BTreeSet<Net>>>,
16}
17
18impl<'a> TreeRebalance<'a> {
19    fn new(design: &Design, levels: &'a LevelAnalysis) -> Self {
20        let mut inner_aig = HashSet::new();
21        let mut inner_xor = HashSet::new();
22        let mut use_count = HashMap::<Net, u32>::new();
23        for cell in design.iter_cells() {
24            cell.visit(|net| {
25                *use_count.entry(net).or_default() += 1;
26            });
27        }
28        for cell in design.iter_cells() {
29            if let Cell::Aig(net1, net2) = *cell.get() {
30                for net in [net1, net2] {
31                    if let ControlNet::Pos(net) = net
32                        && use_count[&net] == 1
33                    {
34                        inner_aig.insert(net);
35                    }
36                }
37            }
38            if let Cell::Xor(ref val1, ref val2) = *cell.get() {
39                for val in [val1, val2] {
40                    for net in val {
41                        if use_count[&net] == 1 {
42                            inner_xor.insert(net);
43                        }
44                    }
45                }
46            }
47        }
48        Self { levels, inner_aig, inner_xor, aig_trees: Default::default(), xor_trees: Default::default() }
49    }
50}
51
52impl RewriteRuleset for TreeRebalance<'_> {
53    fn rewrite<'a>(
54        &self,
55        cell: &Cell,
56        _meta: prjunnamed_netlist::MetaItemRef<'a>,
57        output: Option<&prjunnamed_netlist::Value>,
58        rewriter: &prjunnamed_netlist::Rewriter<'a>,
59    ) -> RewriteResult<'a> {
60        let Some(output) = output else {
61            return RewriteResult::None;
62        };
63        if output.len() != 1 {
64            return RewriteResult::None;
65        }
66        let output = output[0];
67        match *cell {
68            Cell::Aig(net1, net2) => {
69                let mut aig_trees = self.aig_trees.borrow_mut();
70                let mut inputs1 = if let ControlNet::Pos(net) = net1
71                    && let Some(inputs) = aig_trees.remove(&net)
72                {
73                    inputs
74                } else {
75                    BTreeSet::from_iter([net1])
76                };
77                let mut inputs2 = if let ControlNet::Pos(net) = net2
78                    && let Some(inputs) = aig_trees.remove(&net)
79                {
80                    inputs
81                } else {
82                    BTreeSet::from_iter([net2])
83                };
84                if inputs1.len() < inputs2.len() {
85                    std::mem::swap(&mut inputs1, &mut inputs2);
86                }
87                inputs1.extend(inputs2);
88                let inputs = inputs1;
89                if self.inner_aig.contains(&output) {
90                    aig_trees.insert(output, inputs);
91                    RewriteResult::None
92                } else {
93                    if inputs.len() == 2 {
94                        return RewriteResult::None;
95                    }
96                    let mut inputs = BinaryHeap::from_iter(
97                        inputs.into_iter().map(|net| std::cmp::Reverse((self.levels.get(net.net()), net))),
98                    );
99                    while inputs.len() > 1 {
100                        let (lvl1, net1) = inputs.pop().unwrap().0;
101                        let (lvl2, net2) = inputs.pop().unwrap().0;
102                        let lvl = lvl1.max(lvl2) + 1;
103                        let val = rewriter.add_cell(Cell::Aig(net1, net2));
104                        let net = ControlNet::Pos(val[0]);
105                        inputs.push(std::cmp::Reverse((lvl, net)));
106                    }
107                    let net = inputs.pop().unwrap().0.1;
108                    net.into()
109                }
110            }
111            Cell::Xor(ref val1, ref val2) => {
112                let net1 = val1[0];
113                let net2 = val2[0];
114                let mut xor_trees = self.xor_trees.borrow_mut();
115                let mut inputs1 =
116                    if let Some(inputs) = xor_trees.remove(&net1) { inputs } else { BTreeSet::from_iter([net1]) };
117                let mut inputs2 =
118                    if let Some(inputs) = xor_trees.remove(&net2) { inputs } else { BTreeSet::from_iter([net2]) };
119                if inputs1.len() < inputs2.len() {
120                    std::mem::swap(&mut inputs1, &mut inputs2);
121                }
122                for net in inputs2 {
123                    if !inputs1.remove(&net) {
124                        inputs1.insert(net);
125                    }
126                }
127                let inputs = inputs1;
128                if self.inner_xor.contains(&output) {
129                    xor_trees.insert(output, inputs);
130                    RewriteResult::None
131                } else {
132                    if inputs.len() == 2 {
133                        return RewriteResult::None;
134                    }
135                    let mut inputs = BinaryHeap::from_iter(
136                        inputs.into_iter().map(|net| std::cmp::Reverse((self.levels.get(net), net))),
137                    );
138                    while inputs.len() > 1 {
139                        let (lvl1, net1) = inputs.pop().unwrap().0;
140                        let (lvl2, net2) = inputs.pop().unwrap().0;
141                        let lvl = lvl1.max(lvl2) + 1;
142                        let val = rewriter.add_cell(Cell::Xor(net1.into(), net2.into()));
143                        inputs.push(std::cmp::Reverse((lvl, val[0])));
144                    }
145                    let net = inputs.pop().unwrap().0.1;
146                    net.into()
147                }
148            }
149            _ => RewriteResult::None,
150        }
151    }
152
153    fn net_replaced(&self, _design: &Design, from: Net, to: Net) {
154        let mut aig_trees = self.aig_trees.borrow_mut();
155        if let Some(tree) = aig_trees.remove(&from) {
156            aig_trees.insert(to, tree);
157        }
158        let mut xor_trees = self.xor_trees.borrow_mut();
159        if let Some(tree) = xor_trees.remove(&from) {
160            xor_trees.insert(to, tree);
161        }
162    }
163}
164
165pub fn tree_rebalance(design: &mut Design) {
166    let levels = LevelAnalysis::new();
167    let rebalance = TreeRebalance::new(&design, &levels);
168    design.rewrite(&[&Normalize, &SimpleAigOpt, &levels, &rebalance]);
169}