Skip to main content

prjunnamed_netlist/
smt.rs

1use std::{borrow::Cow, cell::RefCell, collections::BTreeMap};
2
3use crate::{AssignCell, Cell, Const, ControlNet, Design, MatchCell, Net, Trit, Value};
4
5#[cfg(feature = "easy-smt")]
6pub mod easy_smt;
7
/// Outcome of a satisfiability query issued through [`SmtEngine::check`].
///
/// The enum is a plain tag, so it is `Copy` and can be compared directly
/// (e.g. `response == SmtResponse::Unsat`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SmtResponse {
    /// The assertions are satisfiable; `SmtBuilder::check` then reads a model back.
    Sat,
    /// The assertions are unsatisfiable.
    Unsat,
    /// The solver could not decide satisfiability.
    Unknown,
}
14
15struct SmtTritVec<SMT: SmtEngine> {
16    x: SMT::BitVec, // is undef?
17    y: SMT::BitVec, // if not undef, is one?
18}
19
20pub trait SmtEngine {
21    type Bool: Clone;
22    type BitVec: Clone;
23
24    fn build_bool_lit(&self, value: bool) -> Self::Bool;
25    fn build_bool_eq(&self, arg1: Self::Bool, arg2: Self::Bool) -> Self::Bool;
26    fn build_bool_ite(&self, cond: Self::Bool, if_true: Self::Bool, if_false: Self::Bool) -> Self::Bool;
27    fn build_bool_let(&self, var: &str, expr: Self::Bool, body: impl FnOnce(Self::Bool) -> Self::Bool) -> Self::Bool;
28    fn build_not(&self, arg: Self::Bool) -> Self::Bool;
29    fn build_and(&self, args: &[Self::Bool]) -> Self::Bool;
30    fn build_or(&self, args: &[Self::Bool]) -> Self::Bool;
31    fn build_xor(&self, args: &[Self::Bool]) -> Self::Bool;
32
33    fn build_bitvec_lit(&self, value: &Const) -> Self::BitVec;
34    fn build_bitvec_eq(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::Bool;
35    fn build_bitvec_ite(&self, cond: Self::Bool, if_true: Self::BitVec, if_false: Self::BitVec) -> Self::BitVec;
36    fn build_bitvec_let(
37        &self,
38        var: &str,
39        expr: Self::BitVec,
40        body: impl FnOnce(Self::BitVec) -> Self::BitVec,
41    ) -> Self::BitVec;
42    fn build_concat(&self, arg_msb: Self::BitVec, arg_lsb: Self::BitVec) -> Self::BitVec;
43    fn build_extract(&self, index_msb: usize, index_lsb: usize, arg: Self::BitVec) -> Self::BitVec;
44    fn build_bvnot(&self, arg1: Self::BitVec) -> Self::BitVec;
45    fn build_bvand(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
46    fn build_bvor(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
47    fn build_bvxor(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
48    fn build_bvadd(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
49    fn build_bvcomp(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
50    fn build_bvult(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::Bool;
51    fn build_bvslt(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::Bool;
52    fn build_bvshl(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
53    fn build_bvlshr(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
54    fn build_bvashr(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
55    fn build_bvmul(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
56    fn build_bvudiv(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
57    fn build_bvurem(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
58    fn build_bvsdiv(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
59    fn build_bvsrem(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
60
61    type Error;
62
63    fn declare_bool_const(&self, name: &str) -> Result<Self::Bool, Self::Error>;
64    fn declare_bitvec_const(&self, name: &str, width: usize) -> Result<Self::BitVec, Self::Error>;
65    fn assert(&mut self, term: Self::Bool) -> Result<(), Self::Error>;
66
67    fn check(&mut self) -> Result<SmtResponse, Self::Error>;
68    fn get_bool(&self, term: &Self::Bool) -> Result<bool, Self::Error>;
69    fn get_bitvec(&self, term: &Self::BitVec) -> Result<Const, Self::Error>;
70}
71
72impl<SMT: SmtEngine> Clone for SmtTritVec<SMT> {
73    fn clone(&self) -> Self {
74        Self { x: self.x.clone(), y: self.y.clone() }
75    }
76}
77
78pub struct SmtBuilder<'a, SMT: SmtEngine> {
79    design: &'a Design,
80    engine: SMT,
81    temp: usize,
82    curr: RefCell<BTreeMap<Net, SmtTritVec<SMT>>>,
83    past: RefCell<BTreeMap<Net, SmtTritVec<SMT>>>,
84    eqs: RefCell<Vec<SMT::Bool>>,
85}
86
87impl<'a, SMT: SmtEngine> SmtBuilder<'a, SMT> {
88    pub fn new(design: &'a Design, engine: SMT) -> Self {
89        Self {
90            design,
91            engine,
92            temp: 0,
93            curr: RefCell::new(BTreeMap::new()),
94            past: RefCell::new(BTreeMap::new()),
95            eqs: RefCell::new(Vec::new()),
96        }
97    }
98
99    fn bv_lit<'b>(&self, value: impl Into<Cow<'b, Const>>) -> SMT::BitVec {
100        self.engine.build_bitvec_lit(&value.into())
101    }
102
103    fn bv_const(&self, prefix: &str, index: usize, suffix: &str, width: usize) -> Result<SMT::BitVec, SMT::Error> {
104        self.engine.declare_bitvec_const(&format!("{prefix}{index}{suffix}"), width)
105    }
106
107    fn bv_bind(&mut self, bv_value: SMT::BitVec, width: usize) -> Result<SMT::BitVec, SMT::Error> {
108        let bv_temp = self.engine.declare_bitvec_const(&format!("t{}", self.temp), width)?;
109        self.temp += 1;
110        self.engine.assert(self.engine.build_bitvec_eq(bv_temp.clone(), bv_value))?;
111        Ok(bv_temp)
112    }
113
114    fn bv_concat(&self, elems: impl IntoIterator<IntoIter: DoubleEndedIterator<Item = SMT::BitVec>>) -> SMT::BitVec {
115        let mut bv_result = None;
116        for elem in elems.into_iter().rev() {
117            bv_result = Some(match bv_result {
118                Some(bv_msb) => self.engine.build_concat(bv_msb, elem),
119                None => elem,
120            })
121        }
122        bv_result.expect("SMT bit vectors cannot be empty")
123    }
124
125    fn bv_is_zero(&self, bv_value: SMT::BitVec, width: usize) -> SMT::Bool {
126        self.engine.build_bitvec_eq(bv_value, self.bv_lit(Const::zero(width)))
127    }
128
129    fn bv_to_bool(&self, bv_value: SMT::BitVec) -> SMT::Bool {
130        self.engine.build_bitvec_eq(bv_value, self.bv_lit(Trit::One))
131    }
132
133    fn bool_to_bv(&self, bool_value: SMT::Bool) -> SMT::BitVec {
134        self.engine.build_bitvec_ite(bool_value, self.bv_lit(Trit::One), self.bv_lit(Trit::Zero))
135    }
136
137    fn tv_lit<'b>(&self, value: impl Into<Cow<'b, Const>>) -> SmtTritVec<SMT> {
138        let value = value.into();
139        let is_undef = Const::from_iter(value.iter().map(|t| if t == Trit::Undef { Trit::One } else { Trit::Zero }));
140        let is_one = Const::from_iter(value.iter().map(|t| if t == Trit::One { Trit::One } else { Trit::Zero }));
141        SmtTritVec { x: self.bv_lit(is_undef), y: self.bv_lit(is_one) }
142    }
143
144    fn tv_const(&self, prefix: &str, index: usize, width: usize) -> Result<SmtTritVec<SMT>, SMT::Error> {
145        Ok(SmtTritVec { x: self.bv_const(prefix, index, "x", width)?, y: self.bv_const(prefix, index, "y", width)? })
146    }
147
148    fn tv_bind(&mut self, tv_value: SmtTritVec<SMT>, width: usize) -> Result<SmtTritVec<SMT>, SMT::Error> {
149        let bv_temp_x = self.engine.declare_bitvec_const(&format!("t{}x", self.temp), width)?;
150        let bv_temp_y = self.engine.declare_bitvec_const(&format!("t{}y", self.temp), width)?;
151        self.temp += 1;
152        self.engine.assert(self.engine.build_bitvec_eq(bv_temp_x.clone(), tv_value.x))?;
153        self.engine.assert(self.engine.build_bitvec_eq(bv_temp_y.clone(), tv_value.y))?;
154        Ok(SmtTritVec { x: bv_temp_x, y: bv_temp_y })
155    }
156
157    fn tv_concat<I>(&self, elems: I) -> SmtTritVec<SMT>
158    where
159        I: IntoIterator<IntoIter: DoubleEndedIterator<Item = SmtTritVec<SMT>>>,
160    {
161        let mut bv_result = None;
162        for elem in elems.into_iter().rev() {
163            bv_result = Some(match bv_result {
164                Some((bv_x_msb, bv_y_msb)) => {
165                    (self.engine.build_concat(bv_x_msb, elem.x), self.engine.build_concat(bv_y_msb, elem.y))
166                }
167                None => (elem.x, elem.y),
168            })
169        }
170        bv_result.map(|(x, y)| SmtTritVec { x, y }).expect("SMT bit vectors cannot be empty")
171    }
172
173    fn tv_extract(&self, index_msb: usize, index_lsb: usize, tv: SmtTritVec<SMT>) -> SmtTritVec<SMT> {
174        SmtTritVec {
175            x: self.engine.build_extract(index_msb, index_lsb, tv.x),
176            y: self.engine.build_extract(index_msb, index_lsb, tv.y),
177        }
178    }
179
180    fn tv_is0(&self, tv: SmtTritVec<SMT>) -> SMT::BitVec {
181        self.engine.build_bvand(self.engine.build_bvnot(tv.y), self.engine.build_bvnot(tv.x))
182    }
183
184    fn tv_is1(&self, tv: SmtTritVec<SMT>) -> SMT::BitVec {
185        self.engine.build_bvand(tv.y, self.engine.build_bvnot(tv.x))
186    }
187
188    fn tv_eq(&self, tv_a: SmtTritVec<SMT>, tv_b: SmtTritVec<SMT>) -> SMT::Bool {
189        let x_eq = self.engine.build_bitvec_eq(tv_a.x, tv_b.x);
190        let y_eq = self.engine.build_bitvec_eq(tv_a.y, tv_b.y);
191        self.engine.build_and(&[x_eq, y_eq])
192    }
193
194    fn tv_refines(&self, from: SmtTritVec<SMT>, to: SmtTritVec<SMT>) -> SMT::Bool {
195        let from_def = self.engine.build_bvnot(from.x.clone());
196        let x_refine = self.engine.build_bitvec_eq(
197            to.x.clone(),                          // if a bit is X in `to`...
198            self.engine.build_bvand(to.x, from.x), // ... it should be X in `from`.
199        );
200        let y_refine = self.engine.build_bitvec_eq(
201            self.engine.build_bvand(from_def.clone(), from.y), // if a bit was 0/1 in `from`...
202            self.engine.build_bvand(from_def.clone(), to.y),   // ... it should be that same value in `to`.
203        );
204        self.engine.build_and(&[x_refine, y_refine])
205    }
206
207    fn tv_not(&self, tv_a: SmtTritVec<SMT>) -> SmtTritVec<SMT> {
208        SmtTritVec { y: self.engine.build_bvnot(tv_a.y), x: tv_a.x }
209    }
210
211    fn tv_and(
212        &mut self,
213        tv_a: SmtTritVec<SMT>,
214        tv_b: SmtTritVec<SMT>,
215        width: usize,
216    ) -> Result<SmtTritVec<SMT>, SMT::Error> {
217        //   0 1 X
218        // 0 0 0 0
219        // 1 0 1 X
220        // X 0 X X
221        let (tv_a, tv_b) = (self.tv_bind(tv_a, width)?, self.tv_bind(tv_b, width)?);
222        let bv_arg1_is0 = self.tv_is0(tv_a.clone());
223        let bv_arg2_is0 = self.tv_is0(tv_b.clone());
224        Ok(SmtTritVec {
225            y: self.engine.build_bvand(tv_a.y, tv_b.y),
226            x: self.engine.build_bvand(
227                self.engine.build_bvor(tv_a.x, tv_b.x),
228                self.engine.build_bvnot(self.engine.build_bvor(bv_arg1_is0, bv_arg2_is0)),
229            ),
230        })
231    }
232
233    fn tv_or(
234        &mut self,
235        tv_a: SmtTritVec<SMT>,
236        tv_b: SmtTritVec<SMT>,
237        width: usize,
238    ) -> Result<SmtTritVec<SMT>, SMT::Error> {
239        //   0 1 X
240        // 0 0 1 X
241        // 1 1 1 1
242        // X X 1 1
243        let (tv_a, tv_b) = (self.tv_bind(tv_a, width)?, self.tv_bind(tv_b, width)?);
244        let bv_arg1_is1 = self.tv_is1(tv_a.clone());
245        let bv_arg2_is1 = self.tv_is1(tv_b.clone());
246        Ok(SmtTritVec {
247            y: self.engine.build_bvor(tv_a.y, tv_b.y),
248            x: self.engine.build_bvand(
249                self.engine.build_bvor(tv_a.x, tv_b.x),
250                self.engine.build_bvnot(self.engine.build_bvor(bv_arg1_is1, bv_arg2_is1)),
251            ),
252        })
253    }
254
255    fn tv_xor(&self, tv_a: SmtTritVec<SMT>, tv_b: SmtTritVec<SMT>) -> SmtTritVec<SMT> {
256        //   0 1 X
257        // 0 0 1 X
258        // 1 1 0 X
259        // X X X X
260        SmtTritVec { y: self.engine.build_bvxor(tv_a.y, tv_b.y), x: self.engine.build_bvor(tv_a.x, tv_b.x) }
261    }
262
263    fn tv_mux(
264        &mut self,
265        tv_s: SmtTritVec<SMT>,
266        tv_a: SmtTritVec<SMT>,
267        tv_b: SmtTritVec<SMT>,
268        width: usize,
269    ) -> Result<SmtTritVec<SMT>, SMT::Error> {
270        // S A B O    S A B O    S A B O
271        // 0 ? 0 0    1 0 ? 0    X 0 0 0
272        // 0 ? 1 1    1 1 ? 1    X 1 1 1
273        // 0 ? X X    1 X ? X    X ? ? X
274        let (tv_s, tv_a, tv_b) = (self.tv_bind(tv_s, 1)?, self.tv_bind(tv_a, width)?, self.tv_bind(tv_b, width)?);
275        let (bv_s_is0, bv_s_is1) = (self.tv_is0(tv_s.clone()), self.tv_is1(tv_s));
276        let (bool_s_is0, bool_s_is1) = (self.bv_to_bool(bv_s_is0), self.bv_to_bool(bv_s_is1));
277        let bv_x_sx = self.engine.build_bvxor(tv_a.y.clone(), tv_b.y.clone());
278        let bv_x_sx = self.engine.build_bvor(bv_x_sx, tv_a.x.clone());
279        let bv_x_sx = self.engine.build_bvor(bv_x_sx, tv_b.x.clone());
280        let bv_x_s1 = self.engine.build_bitvec_ite(bool_s_is1.clone(), tv_a.x, bv_x_sx);
281        let bv_x_s0 = self.engine.build_bitvec_ite(bool_s_is0.clone(), tv_b.x, bv_x_s1);
282        let bv_y_sx = self.engine.build_bvand(tv_a.y.clone(), tv_b.y.clone());
283        let bv_y_s1 = self.engine.build_bitvec_ite(bool_s_is1, tv_a.y, bv_y_sx);
284        let bv_y_s0 = self.engine.build_bitvec_ite(bool_s_is0, tv_b.y, bv_y_s1);
285        Ok(SmtTritVec { y: bv_y_s0, x: bv_x_s0 })
286    }
287
288    fn curr_net(&self, net: Net) -> Result<SmtTritVec<SMT>, SMT::Error> {
289        if let Some(trit) = net.as_const() {
290            Ok(self.tv_lit(trit))
291        } else {
292            let cell_index = net.as_cell_index();
293            let tv_net = self.tv_const("n", cell_index, 1)?;
294            self.curr.borrow_mut().insert(net, tv_net.clone());
295            Ok(tv_net)
296        }
297    }
298
299    fn curr_value(&self, value: &Value) -> Result<SmtTritVec<SMT>, SMT::Error> {
300        Ok(self.tv_concat(value.iter().map(|net| self.curr_net(net)).collect::<Result<Vec<_>, _>>()?))
301    }
302
303    fn past_net(&self, net: Net) -> Result<SmtTritVec<SMT>, SMT::Error> {
304        if let Some(trit) = net.as_const() {
305            Ok(self.tv_lit(trit))
306        } else {
307            let cell_index = net.as_cell_index();
308            let tv_net = self.tv_const("p", cell_index, 1)?;
309            self.curr.borrow_mut().insert(net, tv_net.clone());
310            Ok(tv_net)
311        }
312    }
313
314    fn past_value(&self, value: &Value) -> Result<SmtTritVec<SMT>, SMT::Error> {
315        Ok(self.tv_concat(value.iter().map(|net| self.past_net(net)).collect::<Result<Vec<_>, _>>()?))
316    }
317
318    fn net(&self, net: Net) -> Result<SmtTritVec<SMT>, SMT::Error> {
319        let index = net.as_cell_index();
320        if !self.design.is_valid_cell_index(index) {
321            return self.curr_net(net); // FIXME: is this even sound?
322        }
323        let (cell_ref, _offset) = self.design.find_cell(net);
324        if matches!(&*cell_ref.get(), Cell::Dff(_)) { self.past_net(net) } else { self.curr_net(net) }
325    }
326
327    fn value(&self, value: &Value) -> Result<SmtTritVec<SMT>, SMT::Error> {
328        Ok(self.tv_concat(value.iter().map(|net| self.net(net)).collect::<Result<Vec<_>, _>>()?))
329    }
330
331    fn invert_net(&self, control_net: ControlNet, tv_net: SmtTritVec<SMT>) -> Result<SmtTritVec<SMT>, SMT::Error> {
332        let active = if control_net.is_positive() { Const::ones(1) } else { Const::zero(1) };
333        Ok(SmtTritVec { y: self.engine.build_bvcomp(tv_net.y, self.engine.build_bitvec_lit(&active)), x: tv_net.x })
334    }
335
336    fn control_net(&self, control_net: ControlNet) -> Result<SmtTritVec<SMT>, SMT::Error> {
337        self.invert_net(control_net, self.net(control_net.net())?)
338    }
339
340    fn clock_net(&mut self, control_net: ControlNet) -> Result<SmtTritVec<SMT>, SMT::Error> {
341        self.tv_and(
342            self.invert_net(control_net, self.curr_net(control_net.net())?)?,
343            self.tv_not(self.invert_net(control_net, self.past_net(control_net.net())?)?),
344            1,
345        )
346    }
347
    /// Lowers `cell`, whose outputs are the nets of `output`, into a trit vector describing
    /// what the cell computes in the present step.
    ///
    /// Operands are read through `value`/`net`, so flip-flop outputs contribute their
    /// previous-step values. Callers (e.g. `add_cell`) constrain `output` against the result.
    ///
    /// # Errors
    /// Propagates any error the SMT engine reports while declaring or asserting terms.
    ///
    /// # Panics
    /// Panics for cells that are not lowered yet (division/modulo, memories, target cells),
    /// for instances, and for input/output/name/debug cells.
    fn cell(&mut self, output: &Value, cell: &Cell) -> Result<SmtTritVec<SMT>, SMT::Error> {
        // Extends the shifted value and the shift amount to one common width. The width is
        // large enough that `amount * stride` cannot wrap. Also precomputes that product
        // (from the `y` half of the amount).
        let prepare_shift = |value: &Value, amount: &Value, stride: u32, signed: bool| -> Result<_, SMT::Error> {
            let width = value.len().max(amount.len() + stride.max(1).ilog2() as usize + 1);
            let value_ext = if signed { value.sext(width) } else { value.zext(width) };
            let amount_ext = amount.zext(width);
            let (tv_value_ext, tv_amount_ext) = (self.value(&value_ext)?, self.value(&amount_ext)?);
            let bv_stride = self.engine.build_bitvec_lit(&Const::from_uint(stride as u128, width));
            let bv_amount_mul = self.engine.build_bvmul(tv_amount_ext.y.clone(), bv_stride);
            Ok((width, tv_value_ext, tv_amount_ext, bv_amount_mul))
        };

        let bv_cell = match &cell {
            Cell::Const(a) => self.tv_lit(*a),
            Cell::Buf(a) => self.value(a)?,
            Cell::Not(a) => self.tv_not(self.value(a)?),
            Cell::And(a, b) => self.tv_and(self.value(a)?, self.value(b)?, output.len())?,
            Cell::Or(a, b) => self.tv_or(self.value(a)?, self.value(b)?, output.len())?,
            Cell::Xor(a, b) => self.tv_xor(self.value(a)?, self.value(b)?),
            Cell::Mux(s, a, b) => self.tv_mux(self.net(*s)?, self.value(a)?, self.value(b)?, output.len())?,
            Cell::Adc(a, b, c) => {
                if a.is_empty() {
                    self.net(*c)?
                } else {
                    let (tv_a, tv_b, tv_c) = (self.value(a)?, self.value(b)?, self.net(*c)?);
                    let (tv_a, tv_b) = (self.tv_bind(tv_a, a.len())?, self.tv_bind(tv_b, b.len())?);
                    // Undefinedness propagates up the carry chain. Sum bit `i` is X if the
                    // carry-in or any operand bit at or below `i` is X. The final entry is
                    // the carry-out.
                    let mut bv_x = vec![];
                    let mut bv_carry_x = tv_c.x;
                    for index in 0..a.len() {
                        bv_carry_x = self.bv_bind(
                            self.engine.build_bvor(
                                bv_carry_x,
                                self.engine.build_bvor(
                                    self.engine.build_extract(index, index, tv_a.x.clone()),
                                    self.engine.build_extract(index, index, tv_b.x.clone()),
                                ),
                            ),
                            1,
                        )?;
                        bv_x.push(bv_carry_x.clone());
                    }
                    bv_x.push(bv_carry_x.clone());
                    SmtTritVec {
                        x: self.bv_concat(bv_x),
                        // Operands are zero-extended by one bit so the carry-out lands in the MSB.
                        y: self.engine.build_bvadd(
                            self.engine.build_bvadd(
                                self.engine.build_concat(self.bv_lit(Trit::Zero), tv_a.y),
                                self.engine.build_concat(self.bv_lit(Trit::Zero), tv_b.y),
                            ),
                            self.engine.build_concat(self.bv_lit(Const::zero(a.len())), tv_c.y),
                        ),
                    }
                }
            }
            Cell::Aig(a, b) => self.tv_and(self.control_net(*a)?, self.control_net(*b)?, 1)?,
            Cell::Eq(a, b) => {
                let (tv_a, tv_b) = (self.value(a)?, self.value(b)?);
                // A pair of defined, differing bits makes the result a defined 0 regardless
                // of any X elsewhere; otherwise any X operand bit makes the result X.
                let bv_a_xor_b = self.engine.build_bvxor(tv_a.y.clone(), tv_b.y.clone());
                let bv_unequal = self.engine.build_bvand(
                    bv_a_xor_b,
                    self.engine.build_bvnot(self.engine.build_bvor(tv_a.x.clone(), tv_b.x.clone())),
                );
                let bool_any_unequal = self.engine.build_not(self.bv_is_zero(bv_unequal, a.len()));
                SmtTritVec {
                    x: self.bool_to_bv(self.engine.build_bool_ite(
                        bool_any_unequal,
                        self.engine.build_bool_lit(false),
                        self.engine.build_not(self.engine.build_and(&[
                            self.bv_is_zero(tv_a.x.clone(), a.len()),
                            self.bv_is_zero(tv_b.x.clone(), b.len()),
                        ])),
                    )),
                    y: self.engine.build_bvcomp(tv_a.y, tv_b.y),
                }
            }
            Cell::ULt(a, b) => {
                let (tv_a, tv_b) = (self.value(a)?, self.value(b)?);
                // Any X bit in either operand makes the comparison X.
                SmtTritVec {
                    x: self.bool_to_bv(self.engine.build_not(
                        self.engine.build_and(&[self.bv_is_zero(tv_a.x, a.len()), self.bv_is_zero(tv_b.x, b.len())]),
                    )),
                    y: self.bool_to_bv(self.engine.build_bvult(tv_a.y, tv_b.y)),
                }
            }
            Cell::SLt(a, b) => {
                let (tv_a, tv_b) = (self.value(a)?, self.value(b)?);
                // Any X bit in either operand makes the comparison X.
                SmtTritVec {
                    x: self.bool_to_bv(self.engine.build_not(
                        self.engine.build_and(&[self.bv_is_zero(tv_a.x, a.len()), self.bv_is_zero(tv_b.x, b.len())]),
                    )),
                    y: self.bool_to_bv(self.engine.build_bvslt(tv_a.y, tv_b.y)),
                }
            }
            // Shifts: an X anywhere in the shift amount makes every result bit X. Otherwise
            // the `x` mask is shifted together with the data.
            Cell::Shl(value, amount, stride) => {
                let (width, tv_value_ext, tv_amount_ext, bv_amount_mul) = prepare_shift(value, amount, *stride, false)?;
                let tv_result = SmtTritVec {
                    x: self.engine.build_bitvec_ite(
                        self.bv_is_zero(tv_amount_ext.x, width),
                        self.engine.build_bvshl(tv_value_ext.x, bv_amount_mul.clone()),
                        self.engine.build_bitvec_lit(&Const::ones(width)),
                    ),
                    y: self.engine.build_bvshl(tv_value_ext.y, bv_amount_mul),
                };
                self.tv_extract(value.len() - 1, 0, tv_result)
            }
            Cell::UShr(value, amount, stride) => {
                let (width, tv_value_ext, tv_amount_ext, bv_amount_mul) = prepare_shift(value, amount, *stride, false)?;
                let tv_result = SmtTritVec {
                    x: self.engine.build_bitvec_ite(
                        self.bv_is_zero(tv_amount_ext.x, width),
                        self.engine.build_bvlshr(tv_value_ext.x, bv_amount_mul.clone()),
                        self.engine.build_bitvec_lit(&Const::ones(width)),
                    ),
                    y: self.engine.build_bvlshr(tv_value_ext.y, bv_amount_mul),
                };
                self.tv_extract(value.len() - 1, 0, tv_result)
            }
            Cell::SShr(value, amount, stride) => {
                let (width, tv_value_ext, tv_amount_ext, bv_amount_mul) = prepare_shift(value, amount, *stride, true)?;
                let tv_result = SmtTritVec {
                    x: self.engine.build_bitvec_ite(
                        self.bv_is_zero(tv_amount_ext.x, width),
                        self.engine.build_bvashr(tv_value_ext.x, bv_amount_mul.clone()),
                        self.engine.build_bitvec_lit(&Const::ones(width)),
                    ),
                    y: self.engine.build_bvashr(tv_value_ext.y, bv_amount_mul),
                };
                self.tv_extract(value.len() - 1, 0, tv_result)
            }
            Cell::XShr(value, amount, stride) => {
                // An X bit is appended above the MSB and sign-extended. Arithmetic shifting
                // of `x` therefore fills vacated positions with X.
                let (width, tv_value_ext, tv_amount_ext, bv_amount_mul) =
                    prepare_shift(&value.concat(Value::undef(1)), amount, *stride, true)?;
                let tv_result = SmtTritVec {
                    x: self.engine.build_bitvec_ite(
                        self.bv_is_zero(tv_amount_ext.x, width),
                        self.engine.build_bvashr(tv_value_ext.x, bv_amount_mul.clone()),
                        self.engine.build_bitvec_lit(&Const::ones(width)),
                    ),
                    y: self.engine.build_bvlshr(tv_value_ext.y, bv_amount_mul),
                };
                self.tv_extract(value.len() - 1, 0, tv_result)
            }
            Cell::Mul(a, b) => {
                let (tv_a, tv_b) = (self.value(a)?, self.value(b)?);
                // Any X operand bit makes the whole product X.
                SmtTritVec {
                    x: self.engine.build_bitvec_ite(
                        self.engine.build_and(&[self.bv_is_zero(tv_a.x, a.len()), self.bv_is_zero(tv_b.x, b.len())]),
                        self.engine.build_bitvec_lit(&Const::zero(a.len())),
                        self.engine.build_bitvec_lit(&Const::ones(a.len())),
                    ),
                    y: self.engine.build_bvmul(tv_a.y, tv_b.y),
                }
            }
            Cell::UDiv(..) => unimplemented!("lowering of udiv to SMT-LIB is not implemented"),
            Cell::UMod(..) => unimplemented!("lowering of umod to SMT-LIB is not implemented"),
            Cell::SDivTrunc(..) => unimplemented!("lowering of sdiv_trunc to SMT-LIB is not implemented"),
            Cell::SDivFloor(..) => unimplemented!("lowering of sdiv_floor to SMT-LIB is not implemented"),
            Cell::SModTrunc(..) => unimplemented!("lowering of smod_trunc to SMT-LIB is not implemented"),
            Cell::SModFloor(..) => unimplemented!("lowering of smod_floor to SMT-LIB is not implemented"),
            Cell::Match(MatchCell { value, enable, patterns }) => {
                // Each alternate is a product of per-bit equalities. X trits in a pattern are
                // skipped as "don't care". The alternates of one pattern are ORed together.
                let mut tv_matches = vec![];
                for alternates in patterns {
                    let mut tv_sum = self.tv_lit(Trit::Zero);
                    for pattern in alternates {
                        let mut tv_prd = self.tv_lit(Trit::One);
                        for (index, mask) in pattern.iter().enumerate().filter(|(_index, mask)| *mask != Trit::Undef) {
                            let tv_eq = self.tv_not(self.tv_xor(self.net(value[index])?, self.net(mask.into())?));
                            tv_prd = self.tv_and(tv_prd, tv_eq, 1)?;
                        }
                        tv_sum = self.tv_or(tv_sum, tv_prd, 1)?;
                    }
                    tv_matches.push(tv_sum);
                }
                let tv_matches = self.tv_bind(self.tv_concat(tv_matches), patterns.len())?;
                let tv_enable = self.net(*enable)?;
                let tv_all_cold = self.tv_lit(Const::zero(patterns.len()));
                // Priority chain built from the last pattern upward, so the lowest-index
                // matching pattern wins. The result is one-hot, or all zeros if none matches.
                let mut tv_result = tv_all_cold.clone();
                for index in (0..patterns.len()).rev() {
                    tv_result = self.tv_mux(
                        self.tv_extract(index, index, tv_matches.clone()),
                        self.tv_lit(Const::one_hot(patterns.len(), index)),
                        tv_result,
                        patterns.len(),
                    )?;
                }
                self.tv_mux(tv_enable.clone(), tv_result, tv_all_cold, patterns.len())?
            }
            // When enabled, `update` is spliced into `value` starting at `offset`; otherwise
            // `value` passes through.
            Cell::Assign(AssignCell { value, enable, update, offset }) => self.tv_mux(
                self.net(*enable)?,
                self.value(&{
                    let mut nets = Vec::from_iter(value.iter());
                    nets[*offset..(*offset + update.len())].copy_from_slice(&update[..]);
                    Value::from_iter(nets)
                })?,
                self.value(value)?,
                output.len(),
            )?,
            Cell::Dff(flip_flop) => {
                let mut data = self.value(&flip_flop.data)?;
                let clear = self.control_net(flip_flop.clear)?;
                let load = self.control_net(flip_flop.load)?;
                let load_data = self.value(&flip_flop.load_data)?;
                let reset = self.control_net(flip_flop.reset)?;
                let enable = self.control_net(flip_flop.enable)?;
                // Synchronous reset and enable. With `reset_over_enable`, reset applies even
                // when the flip-flop is disabled; otherwise enable gates the reset too.
                if flip_flop.reset_over_enable {
                    data = self.tv_mux(enable, data, self.past_value(output)?, output.len())?;
                    data = self.tv_mux(reset, self.tv_lit(&flip_flop.reset_value), data, output.len())?;
                } else {
                    data = self.tv_mux(reset, self.tv_lit(&flip_flop.reset_value), data, output.len())?;
                    data = self.tv_mux(enable, data, self.past_value(output)?, output.len())?;
                }
                // The new data is captured only on an active clock edge; otherwise the
                // previous-step value is held.
                let active_edge = self.clock_net(flip_flop.clock)?;
                let mut value = self.tv_mux(active_edge, data, self.past_value(output)?, output.len())?;
                // Asynchronous load, then clear. Clear is applied last and takes precedence.
                if flip_flop.has_load() {
                    value = self.tv_mux(load, load_data, value, output.len())?
                }
                if flip_flop.has_clear() {
                    value = self.tv_mux(clear, self.tv_lit(&flip_flop.clear_value), value, output.len())?
                }
                value
            }
            Cell::Memory(_memory) => unimplemented!("memories are not lowered to SMT-LIB yet"),
            Cell::IoBuf(_io_buffer) => self.value(output)?, // i/en/o treated as POs/PIs
            Cell::Target(_target_cell) => unimplemented!("target cells cannot be lowered to SMT-LIB yet"),
            Cell::Other(_) => unreachable!("instances cannot be lowered to SMT-LIB"),
            Cell::Input(..) | Cell::Output(..) | Cell::Name(..) | Cell::Debug(..) => unreachable!(),
        };

        Ok(bv_cell)
    }
577
578    pub fn add_cell(&mut self, output: &Value, cell: &Cell) -> Result<(), SMT::Error> {
579        // Declare the nets used by the cell so that it is present in the counterexample even if unused.
580        if let Cell::Input(..) = cell {
581            self.curr_value(output)?;
582            return Ok(());
583        }
584        let tv_cell = self.cell(output, cell)?;
585        self.engine.assert(self.tv_eq(self.curr_value(output)?, tv_cell))
586    }
587
588    pub fn replace_cell(&mut self, output: &Value, old_cell: &Cell, new_cell: &Cell) -> Result<(), SMT::Error> {
589        self.add_cell(output, old_cell)?;
590        let tv_new_cell = self.cell(output, new_cell)?;
591        self.eqs.borrow_mut().push(self.tv_refines(self.curr_value(output)?, tv_new_cell));
592        Ok(())
593    }
594
595    pub fn replace_net(&mut self, from_net: Net, to_net: Net) -> Result<(), SMT::Error> {
596        self.eqs.borrow_mut().push(self.tv_refines(self.curr_net(from_net)?, self.net(to_net)?));
597        Ok(())
598    }
599
600    pub fn replace_void_net(&mut self, from_net: Net, to_net: Net) -> Result<(), SMT::Error> {
601        self.engine.assert(self.tv_eq(self.curr_net(from_net)?, self.net(to_net)?))
602    }
603
604    pub fn replace_dff_net(&mut self, from_net: Net, to_net: Net) -> Result<(), SMT::Error> {
605        // Essentially a single induction step.
606        self.engine.assert(self.tv_eq(self.past_net(from_net)?, self.past_net(to_net)?))?;
607        self.eqs.borrow_mut().push(self.tv_refines(self.curr_net(from_net)?, self.curr_net(to_net)?));
608        Ok(())
609    }
610
611    pub fn check(&mut self) -> Result<Option<SmtExample>, SMT::Error> {
612        if self.eqs.borrow().is_empty() {
613            return Ok(None);
614        }
615        let not_and_eqs = self.engine.build_not(self.engine.build_and(&self.eqs.borrow()[..]));
616        self.engine.assert(not_and_eqs)?;
617        match self.engine.check()? {
618            SmtResponse::Unknown => panic!("SMT solver returned unknown"),
619            SmtResponse::Unsat => Ok(None),
620            SmtResponse::Sat => {
621                let get_trit = |tv_net: &SmtTritVec<SMT>| -> Result<Trit, SMT::Error> {
622                    if self.engine.get_bitvec(&tv_net.x)?[0] == Trit::One {
623                        Ok(Trit::Undef)
624                    } else {
625                        Ok(self.engine.get_bitvec(&tv_net.y)?[0])
626                    }
627                };
628                let (mut curr, mut past) = (BTreeMap::new(), BTreeMap::new());
629                for (net, tv_net) in self.curr.borrow().iter() {
630                    curr.insert(*net, get_trit(tv_net)?);
631                }
632                for (net, tv_net) in self.past.borrow().iter() {
633                    past.insert(*net, get_trit(tv_net)?);
634                }
635                Ok(Some(SmtExample { curr, past }))
636            }
637        }
638    }
639}
640
641#[derive(Debug, Clone, PartialEq, Eq)]
642pub struct SmtExample {
643    curr: BTreeMap<Net, Trit>,
644    past: BTreeMap<Net, Trit>,
645}
646
647impl SmtExample {
648    pub fn get_net(&self, net: Net) -> Option<Trit> {
649        self.curr.get(&net).cloned()
650    }
651
652    pub fn get_value<'a>(&self, value: impl Into<Cow<'a, Value>>) -> Option<Const> {
653        let mut result = Const::new();
654        for net in &*value.into() {
655            result.push(self.get_net(net)?);
656        }
657        Some(result)
658    }
659
660    pub fn get_past_net(&self, net: Net) -> Option<Trit> {
661        self.past.get(&net).cloned()
662    }
663
664    pub fn get_past_value<'a>(&self, value: impl Into<Cow<'a, Value>>) -> Option<Const> {
665        let mut result = Const::new();
666        for net in &*value.into() {
667            result.push(self.get_past_net(net)?);
668        }
669        Some(result)
670    }
671}