prjunnamed_generic/
merge.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use std::{borrow::Cow, collections::HashMap};

use prjunnamed_netlist::{Cell, Design, Value};

/// Table that associates a [`Cell`] with the [`Value`] recorded for the first such cell seen.
///
/// Keys are cells as rebuilt by the methods on this type (per bit and/or with operands put into
/// a fixed order), so that cells computing the same function hash to the same entry.
struct Numberer(HashMap<Cell, Value>);

impl Numberer {
    fn new() -> Self {
        Numberer(HashMap::new())
    }

    fn find_or_insert<'a>(&mut self, cell: Cell, value: impl Into<Cow<'a, Value>>) -> Value {
        self.0.entry(cell).or_insert_with(|| value.into().into_owned()).clone()
    }

    fn bitwise_unary<F>(&mut self, rebuild: F, arg: Value, out: &Value) -> Value
    where
        F: Fn(Value) -> Cell,
    {
        let mut result = Value::new();
        for (out_net, arg_net) in out.iter().zip(arg.iter()) {
            let bit_cell = rebuild(Value::from(arg_net));
            result.extend(self.find_or_insert(bit_cell, out_net));
        }
        result
    }

    fn commutative_bitwise_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
    where
        F: Fn(Value, Value) -> Cell,
    {
        let mut result = Value::new();
        for (out_net, (arg1_net, arg2_net)) in out.iter().zip(arg1.iter().zip(arg2.iter())) {
            let (arg1_net, arg2_net) = if arg1_net <= arg2_net { (arg1_net, arg2_net) } else { (arg2_net, arg1_net) };
            let bit_cell = rebuild(Value::from(arg1_net), Value::from(arg2_net));
            result.extend(self.find_or_insert(bit_cell, out_net));
        }
        result
    }

    fn bitwise_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
    where
        F: Fn(Value, Value) -> Cell,
    {
        let mut result = Value::new();
        for (out_net, (arg1_net, arg2_net)) in out.iter().zip(arg1.iter().zip(arg2.iter())) {
            let bit_cell = rebuild(Value::from(arg1_net), Value::from(arg2_net));
            result.extend(self.find_or_insert(bit_cell, out_net));
        }
        result
    }

    fn commutative_binary<F>(&mut self, rebuild: F, arg1: Value, arg2: Value, out: &Value) -> Value
    where
        F: Fn(Value, Value) -> Cell,
    {
        let (arg1, arg2) = if arg1 <= arg2 { (arg1, arg2) } else { (arg2, arg1) };
        let cell = rebuild(arg1, arg2);
        self.find_or_insert(cell, out)
    }
}

/// Merges cells of `design` that compute the same value.
///
/// Cells are visited in the order produced by `Design::iter_cells_topo`; cells reporting
/// `has_effects` are left alone. For each remaining cell:
///
/// 1. its input nets are rewritten through `Design::map_net`;
/// 2. a canonical value is looked up (or recorded) in a [`Numberer`] — per bit for `Buf`, `Not`,
///    `And`, `Or`, `Xor` and `Mux`, with operand order ignored for `And`/`Or`/`Xor`; as a whole
///    cell with operand order ignored for `Adc`, `Eq` and `Mul`; and as a whole cell otherwise;
/// 3. the cell's metadata is appended to every cell driving the canonical value;
/// 4. the cell's output is replaced with the canonical value.
///
/// Returns the result of `Design::compact` (presumably whether the design changed — confirm
/// against that method's documentation).
pub fn merge(design: &mut Design) -> bool {
    let mut table = Numberer::new();
    for cell_ref in design.iter_cells_topo() {
        // Cells with effects are never candidates for merging.
        if cell_ref.get().has_effects(design) {
            continue;
        }

        // Work on an owned copy whose inputs are rewritten through the design's net mapping.
        let mut cell = cell_ref.get().into_owned();
        cell.visit_mut(|net| *net = design.map_net(*net));

        let output = cell_ref.output();
        let canon = match cell {
            Cell::Buf(arg) => table.bitwise_unary(Cell::Buf, arg, &output),
            Cell::Not(arg) => table.bitwise_unary(Cell::Not, arg, &output),
            Cell::And(lhs, rhs) => table.commutative_bitwise_binary(Cell::And, lhs, rhs, &output),
            Cell::Or(lhs, rhs) => table.commutative_bitwise_binary(Cell::Or, lhs, rhs, &output),
            Cell::Xor(lhs, rhs) => table.commutative_bitwise_binary(Cell::Xor, lhs, rhs, &output),
            // The first operand is kept as-is in every per-bit key; only the other two are split.
            Cell::Mux(arg1, arg2, arg3) => {
                table.bitwise_binary(|bit2, bit3| Cell::Mux(arg1, bit2, bit3), arg2, arg3, &output)
            }
            // The first two operands may be swapped; the third stays in place.
            Cell::Adc(arg1, arg2, arg3) => {
                table.commutative_binary(|lhs, rhs| Cell::Adc(lhs, rhs, arg3), arg1, arg2, &output)
            }
            Cell::Eq(lhs, rhs) => table.commutative_binary(Cell::Eq, lhs, rhs, &output),
            Cell::Mul(lhs, rhs) => table.commutative_binary(Cell::Mul, lhs, rhs, &output),
            _ => table.find_or_insert(cell, output.clone()),
        };

        if cfg!(feature = "trace") && output != canon {
            eprintln!(">merge {} => {}", design.display_value(&output), design.display_value(&canon));
        }

        // Carry this cell's metadata over to whichever cells drive the canonical value.
        for canon_net in canon.iter() {
            match design.find_cell(canon_net) {
                Ok((canon_cell_ref, _offset)) => {
                    canon_cell_ref.append_metadata(cell_ref.metadata());
                }
                Err(_) => unreachable!(),
            }
        }

        design.replace_value(output, canon);
    }
    design.compact()
}

#[cfg(test)]
mod test {
    use prjunnamed_netlist::{assert_isomorphic, Design, Value};

    use crate::merge::merge;

    /// A two-bit `xor(b, a)` must be merged with the single-bit `xor(a[i], b[i])` cells.
    #[test]
    fn test_merge_commutative_xor() {
        // Design under test: two single-bit XORs plus one two-bit XOR with swapped operands.
        let mut actual = Design::new();
        let a = actual.add_input("a", 2);
        let b = actual.add_input("b", 2);
        let bit0 = actual.add_xor1(a[0], b[0]);
        let bit1 = actual.add_xor1(a[1], b[1]);
        let wide = actual.add_xor(b, a);
        let y = Value::from(bit0).concat(bit1).concat(wide);
        actual.add_output("y", y);
        actual.apply();
        merge(&mut actual);

        // Reference design: only the single-bit XORs, each one driving two output bits.
        let mut expected = Design::new();
        let a = expected.add_input("a", 2);
        let b = expected.add_input("b", 2);
        let bit0 = expected.add_xor1(a[0], b[0]);
        let bit1 = expected.add_xor1(a[1], b[1]);
        let y = Value::from(bit0).concat(bit1).concat(bit0).concat(bit1);
        expected.add_output("y", y);
        expected.apply();

        assert_isomorphic!(actual, expected);
    }
}