prjunnamed_netlist/
smt.rs

1use std::{borrow::Cow, cell::RefCell, collections::BTreeMap};
2
3use crate::{AssignCell, Cell, Const, ControlNet, Design, MatchCell, Net, Trit, Value};
4
5#[cfg(feature = "easy-smt")]
6pub mod easy_smt;
7
/// Outcome of a satisfiability query, as returned by [`SmtEngine::check`].
#[derive(Debug, Clone)]
pub enum SmtResponse {
    /// The assertions are satisfiable; `SmtBuilder::check` then reads a counterexample back from the engine.
    Sat,
    /// The assertions are unsatisfiable; `SmtBuilder::check` reports this as "no counterexample".
    Unsat,
    /// The solver could not decide; `SmtBuilder::check` panics on this response.
    Unknown,
}
14
/// A vector of three-valued (0/1/X) bits, encoded as two bit vectors of equal width.
///
/// Bit `i` is undefined when bit `i` of `x` is set; otherwise its value is bit `i` of `y`.
struct SmtTritVec<SMT: SmtEngine> {
    x: SMT::BitVec, // is undef?
    y: SMT::BitVec, // if not undef, is one?
}
19
/// Backend interface that [`SmtBuilder`] uses to construct terms, declare constants, add assertions, and query
/// the solver.
///
/// Only the signatures are visible here; the meaning of each operation is defined by the implementor.
/// NOTE(review): the `build_bv*` names match the SMT-LIB bit-vector operators of the same name; presumably each
/// method builds the corresponding term -- confirm against the implementations.
pub trait SmtEngine {
    /// Term type for boolean expressions.
    type Bool: Clone;
    /// Term type for bit-vector expressions.
    type BitVec: Clone;

    // Boolean term constructors.
    fn build_bool_lit(&self, value: bool) -> Self::Bool;
    fn build_bool_eq(&self, arg1: Self::Bool, arg2: Self::Bool) -> Self::Bool;
    fn build_bool_ite(&self, cond: Self::Bool, if_true: Self::Bool, if_false: Self::Bool) -> Self::Bool;
    fn build_bool_let(&self, var: &str, expr: Self::Bool, body: impl FnOnce(Self::Bool) -> Self::Bool) -> Self::Bool;
    fn build_not(&self, arg: Self::Bool) -> Self::Bool;
    fn build_and(&self, args: &[Self::Bool]) -> Self::Bool;
    fn build_or(&self, args: &[Self::Bool]) -> Self::Bool;
    fn build_xor(&self, args: &[Self::Bool]) -> Self::Bool;

    // Bit-vector term constructors. Comparisons (`eq`, `bvult`, `bvslt`) produce a `Bool`, whereas `bvcomp`
    // produces a `BitVec`.
    fn build_bitvec_lit(&self, value: &Const) -> Self::BitVec;
    fn build_bitvec_eq(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::Bool;
    fn build_bitvec_ite(&self, cond: Self::Bool, if_true: Self::BitVec, if_false: Self::BitVec) -> Self::BitVec;
    fn build_bitvec_let(
        &self,
        var: &str,
        expr: Self::BitVec,
        body: impl FnOnce(Self::BitVec) -> Self::BitVec,
    ) -> Self::BitVec;
    /// Concatenates two bit vectors; the first argument supplies the most significant part.
    fn build_concat(&self, arg_msb: Self::BitVec, arg_lsb: Self::BitVec) -> Self::BitVec;
    /// Extracts the bits between `index_lsb` and `index_msb` of `arg`; `SmtBuilder` treats both bounds as inclusive.
    fn build_extract(&self, index_msb: usize, index_lsb: usize, arg: Self::BitVec) -> Self::BitVec;
    fn build_bvnot(&self, arg1: Self::BitVec) -> Self::BitVec;
    fn build_bvand(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvor(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvxor(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvadd(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvcomp(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvult(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::Bool;
    fn build_bvslt(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::Bool;
    fn build_bvshl(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvlshr(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvashr(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvmul(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvudiv(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvurem(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvsdiv(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;
    fn build_bvsrem(&self, arg1: Self::BitVec, arg2: Self::BitVec) -> Self::BitVec;

    /// Error type returned by the fallible operations below.
    type Error;

    // Declarations and assertions.
    fn declare_bool_const(&self, name: &str) -> Result<Self::Bool, Self::Error>;
    fn declare_bitvec_const(&self, name: &str, width: usize) -> Result<Self::BitVec, Self::Error>;
    fn assert(&mut self, term: Self::Bool) -> Result<(), Self::Error>;

    // Solving and model queries. `SmtBuilder::check` only calls the `get_*` methods after `check` returned
    // `SmtResponse::Sat`.
    fn check(&mut self) -> Result<SmtResponse, Self::Error>;
    fn get_bool(&self, term: &Self::Bool) -> Result<bool, Self::Error>;
    fn get_bitvec(&self, term: &Self::BitVec) -> Result<Const, Self::Error>;
}
71
72impl<SMT: SmtEngine> Clone for SmtTritVec<SMT> {
73    fn clone(&self) -> Self {
74        Self { x: self.x.clone(), y: self.y.clone() }
75    }
76}
77
/// Lowers cells of a [`Design`] into SMT terms over three-valued (0/1/X) bit vectors and collects refinement
/// obligations that `check` hands to the engine.
pub struct SmtBuilder<'a, SMT: SmtEngine> {
    /// Design consulted to find out which cell drives a net.
    design: &'a Design,
    /// Solver backend used to build terms, assert them, and query models.
    engine: SMT,
    /// Index used to name the next temporary constant (`t{temp}`, `t{temp}x`, `t{temp}y`).
    temp: usize,
    /// Per-net constants for the current cycle; `check` reads these to fill in `SmtExample::curr`.
    curr: RefCell<BTreeMap<Net, SmtTritVec<SMT>>>,
    /// Per-net constants for the previous cycle; `check` reads these to fill in `SmtExample::past`.
    past: RefCell<BTreeMap<Net, SmtTritVec<SMT>>>,
    /// Refinement obligations; `check` asserts the negation of their conjunction.
    eqs: RefCell<Vec<SMT::Bool>>,
}
86
87impl<'a, SMT: SmtEngine> SmtBuilder<'a, SMT> {
88    pub fn new(design: &'a Design, engine: SMT) -> Self {
89        Self {
90            design,
91            engine,
92            temp: 0,
93            curr: RefCell::new(BTreeMap::new()),
94            past: RefCell::new(BTreeMap::new()),
95            eqs: RefCell::new(Vec::new()),
96        }
97    }
98
99    fn bv_lit<'b>(&self, value: impl Into<Cow<'b, Const>>) -> SMT::BitVec {
100        self.engine.build_bitvec_lit(&*value.into())
101    }
102
103    fn bv_const(&self, prefix: &str, index: usize, suffix: &str, width: usize) -> Result<SMT::BitVec, SMT::Error> {
104        self.engine.declare_bitvec_const(&format!("{prefix}{index}{suffix}"), width)
105    }
106
107    fn bv_bind(&mut self, bv_value: SMT::BitVec, width: usize) -> Result<SMT::BitVec, SMT::Error> {
108        let bv_temp = self.engine.declare_bitvec_const(&format!("t{}", self.temp), width)?;
109        self.temp += 1;
110        self.engine.assert(self.engine.build_bitvec_eq(bv_temp.clone(), bv_value))?;
111        Ok(bv_temp)
112    }
113
114    fn bv_concat(&self, elems: impl IntoIterator<IntoIter: DoubleEndedIterator<Item = SMT::BitVec>>) -> SMT::BitVec {
115        let mut bv_result = None;
116        for elem in elems.into_iter().rev() {
117            bv_result = Some(match bv_result {
118                Some(bv_msb) => self.engine.build_concat(bv_msb, elem),
119                None => elem,
120            })
121        }
122        bv_result.expect("SMT bit vectors cannot be empty")
123    }
124
125    fn bv_is_zero(&self, bv_value: SMT::BitVec, width: usize) -> SMT::Bool {
126        self.engine.build_bitvec_eq(bv_value, self.bv_lit(Const::zero(width)))
127    }
128
129    fn bv_to_bool(&self, bv_value: SMT::BitVec) -> SMT::Bool {
130        self.engine.build_bitvec_eq(bv_value, self.bv_lit(Trit::One))
131    }
132
133    fn bool_to_bv(&self, bool_value: SMT::Bool) -> SMT::BitVec {
134        self.engine.build_bitvec_ite(bool_value, self.bv_lit(Trit::One), self.bv_lit(Trit::Zero))
135    }
136
137    fn tv_lit<'b>(&self, value: impl Into<Cow<'b, Const>>) -> SmtTritVec<SMT> {
138        let value = value.into();
139        let is_undef = Const::from_iter(value.iter().map(|t| if t == Trit::Undef { Trit::One } else { Trit::Zero }));
140        let is_one = Const::from_iter(value.iter().map(|t| if t == Trit::One { Trit::One } else { Trit::Zero }));
141        SmtTritVec { x: self.bv_lit(is_undef), y: self.bv_lit(is_one) }
142    }
143
144    fn tv_const(&self, prefix: &str, index: usize, width: usize) -> Result<SmtTritVec<SMT>, SMT::Error> {
145        Ok(SmtTritVec { x: self.bv_const(prefix, index, "x", width)?, y: self.bv_const(prefix, index, "y", width)? })
146    }
147
148    fn tv_bind(&mut self, tv_value: SmtTritVec<SMT>, width: usize) -> Result<SmtTritVec<SMT>, SMT::Error> {
149        let bv_temp_x = self.engine.declare_bitvec_const(&format!("t{}x", self.temp), width)?;
150        let bv_temp_y = self.engine.declare_bitvec_const(&format!("t{}y", self.temp), width)?;
151        self.temp += 1;
152        self.engine.assert(self.engine.build_bitvec_eq(bv_temp_x.clone(), tv_value.x))?;
153        self.engine.assert(self.engine.build_bitvec_eq(bv_temp_y.clone(), tv_value.y))?;
154        Ok(SmtTritVec { x: bv_temp_x, y: bv_temp_y })
155    }
156
157    fn tv_concat<I>(&self, elems: I) -> SmtTritVec<SMT>
158    where
159        I: IntoIterator<IntoIter: DoubleEndedIterator<Item = SmtTritVec<SMT>>>,
160    {
161        let mut bv_result = None;
162        for elem in elems.into_iter().rev() {
163            bv_result = Some(match bv_result {
164                Some((bv_x_msb, bv_y_msb)) => {
165                    (self.engine.build_concat(bv_x_msb, elem.x), self.engine.build_concat(bv_y_msb, elem.y))
166                }
167                None => (elem.x, elem.y),
168            })
169        }
170        bv_result.map(|(x, y)| SmtTritVec { x, y }).expect("SMT bit vectors cannot be empty")
171    }
172
173    fn tv_extract(&self, index_msb: usize, index_lsb: usize, tv: SmtTritVec<SMT>) -> SmtTritVec<SMT> {
174        SmtTritVec {
175            x: self.engine.build_extract(index_msb, index_lsb, tv.x),
176            y: self.engine.build_extract(index_msb, index_lsb, tv.y),
177        }
178    }
179
180    fn tv_is0(&self, tv: SmtTritVec<SMT>) -> SMT::BitVec {
181        self.engine.build_bvand(self.engine.build_bvnot(tv.y), self.engine.build_bvnot(tv.x))
182    }
183
184    fn tv_is1(&self, tv: SmtTritVec<SMT>) -> SMT::BitVec {
185        self.engine.build_bvand(tv.y, self.engine.build_bvnot(tv.x))
186    }
187
188    fn tv_eq(&self, tv_a: SmtTritVec<SMT>, tv_b: SmtTritVec<SMT>) -> SMT::Bool {
189        let x_eq = self.engine.build_bitvec_eq(tv_a.x, tv_b.x);
190        let y_eq = self.engine.build_bitvec_eq(tv_a.y, tv_b.y);
191        self.engine.build_and(&[x_eq, y_eq])
192    }
193
194    fn tv_refines(&self, from: SmtTritVec<SMT>, to: SmtTritVec<SMT>) -> SMT::Bool {
195        let from_def = self.engine.build_bvnot(from.x.clone());
196        let x_refine = self.engine.build_bitvec_eq(
197            to.x.clone(),                          // if a bit is X in `to`...
198            self.engine.build_bvand(to.x, from.x), // ... it should be X in `from`.
199        );
200        let y_refine = self.engine.build_bitvec_eq(
201            self.engine.build_bvand(from_def.clone(), from.y), // if a bit was 0/1 in `from`...
202            self.engine.build_bvand(from_def.clone(), to.y),   // ... it should be that same value in `to`.
203        );
204        self.engine.build_and(&[x_refine, y_refine])
205    }
206
207    fn tv_not(&self, tv_a: SmtTritVec<SMT>) -> SmtTritVec<SMT> {
208        SmtTritVec { y: self.engine.build_bvnot(tv_a.y), x: tv_a.x }
209    }
210
211    fn tv_and(
212        &mut self,
213        tv_a: SmtTritVec<SMT>,
214        tv_b: SmtTritVec<SMT>,
215        width: usize,
216    ) -> Result<SmtTritVec<SMT>, SMT::Error> {
217        //   0 1 X
218        // 0 0 0 0
219        // 1 0 1 X
220        // X 0 X X
221        let (tv_a, tv_b) = (self.tv_bind(tv_a, width)?, self.tv_bind(tv_b, width)?);
222        let bv_arg1_is0 = self.tv_is0(tv_a.clone());
223        let bv_arg2_is0 = self.tv_is0(tv_b.clone());
224        Ok(SmtTritVec {
225            y: self.engine.build_bvand(tv_a.y, tv_b.y),
226            x: self.engine.build_bvand(
227                self.engine.build_bvor(tv_a.x, tv_b.x),
228                self.engine.build_bvnot(self.engine.build_bvor(bv_arg1_is0, bv_arg2_is0)),
229            ),
230        })
231    }
232
233    fn tv_or(
234        &mut self,
235        tv_a: SmtTritVec<SMT>,
236        tv_b: SmtTritVec<SMT>,
237        width: usize,
238    ) -> Result<SmtTritVec<SMT>, SMT::Error> {
239        //   0 1 X
240        // 0 0 1 X
241        // 1 1 1 1
242        // X X 1 1
243        let (tv_a, tv_b) = (self.tv_bind(tv_a, width)?, self.tv_bind(tv_b, width)?);
244        let bv_arg1_is1 = self.tv_is1(tv_a.clone());
245        let bv_arg2_is1 = self.tv_is1(tv_b.clone());
246        Ok(SmtTritVec {
247            y: self.engine.build_bvor(tv_a.y, tv_b.y),
248            x: self.engine.build_bvand(
249                self.engine.build_bvor(tv_a.x, tv_b.x),
250                self.engine.build_bvnot(self.engine.build_bvor(bv_arg1_is1, bv_arg2_is1)),
251            ),
252        })
253    }
254
255    fn tv_xor(&self, tv_a: SmtTritVec<SMT>, tv_b: SmtTritVec<SMT>) -> SmtTritVec<SMT> {
256        //   0 1 X
257        // 0 0 1 X
258        // 1 1 0 X
259        // X X X X
260        SmtTritVec { y: self.engine.build_bvxor(tv_a.y, tv_b.y), x: self.engine.build_bvor(tv_a.x, tv_b.x) }
261    }
262
263    fn tv_mux(
264        &mut self,
265        tv_s: SmtTritVec<SMT>,
266        tv_a: SmtTritVec<SMT>,
267        tv_b: SmtTritVec<SMT>,
268        width: usize,
269    ) -> Result<SmtTritVec<SMT>, SMT::Error> {
270        // S A B O    S A B O    S A B O
271        // 0 ? 0 0    1 0 ? 0    X 0 0 0
272        // 0 ? 1 1    1 1 ? 1    X 1 1 1
273        // 0 ? X X    1 X ? X    X ? ? X
274        let (tv_s, tv_a, tv_b) = (self.tv_bind(tv_s, 1)?, self.tv_bind(tv_a, width)?, self.tv_bind(tv_b, width)?);
275        let (bv_s_is0, bv_s_is1) = (self.tv_is0(tv_s.clone()), self.tv_is1(tv_s));
276        let (bool_s_is0, bool_s_is1) = (self.bv_to_bool(bv_s_is0), self.bv_to_bool(bv_s_is1));
277        let bv_x_sx = self.engine.build_bvxor(tv_a.y.clone(), tv_b.y.clone());
278        let bv_x_sx = self.engine.build_bvor(bv_x_sx, tv_a.x.clone());
279        let bv_x_sx = self.engine.build_bvor(bv_x_sx, tv_b.x.clone());
280        let bv_x_s1 = self.engine.build_bitvec_ite(bool_s_is1.clone(), tv_a.x, bv_x_sx);
281        let bv_x_s0 = self.engine.build_bitvec_ite(bool_s_is0.clone(), tv_b.x, bv_x_s1);
282        let bv_y_sx = self.engine.build_bvand(tv_a.y.clone(), tv_b.y.clone());
283        let bv_y_s1 = self.engine.build_bitvec_ite(bool_s_is1, tv_a.y, bv_y_sx);
284        let bv_y_s0 = self.engine.build_bitvec_ite(bool_s_is0, tv_b.y, bv_y_s1);
285        Ok(SmtTritVec { y: bv_y_s0, x: bv_x_s0 })
286    }
287
288    fn curr_net(&self, net: Net) -> Result<SmtTritVec<SMT>, SMT::Error> {
289        match net.as_cell_index() {
290            Err(trit) => Ok(self.tv_lit(trit)),
291            Ok(cell_index) => {
292                let tv_net = self.tv_const("n", cell_index, 1)?;
293                self.curr.borrow_mut().insert(net, tv_net.clone());
294                Ok(tv_net)
295            }
296        }
297    }
298
299    fn curr_value(&self, value: &Value) -> Result<SmtTritVec<SMT>, SMT::Error> {
300        Ok(self.tv_concat(value.iter().map(|net| self.curr_net(net)).collect::<Result<Vec<_>, _>>()?.into_iter()))
301    }
302
303    fn past_net(&self, net: Net) -> Result<SmtTritVec<SMT>, SMT::Error> {
304        match net.as_cell_index() {
305            Err(trit) => Ok(self.tv_lit(trit)),
306            Ok(cell_index) => {
307                let tv_net = self.tv_const("p", cell_index, 1)?;
308                self.curr.borrow_mut().insert(net, tv_net.clone());
309                Ok(tv_net)
310            }
311        }
312    }
313
314    fn past_value(&self, value: &Value) -> Result<SmtTritVec<SMT>, SMT::Error> {
315        Ok(self.tv_concat(value.iter().map(|net| self.past_net(net)).collect::<Result<Vec<_>, _>>()?.into_iter()))
316    }
317
318    fn net(&self, net: Net) -> Result<SmtTritVec<SMT>, SMT::Error> {
319        if let Ok(index) = net.as_cell_index() {
320            if !self.design.is_valid_cell_index(index) {
321                return self.curr_net(net); // FIXME: is this even sound?
322            }
323        }
324        match self.design.find_cell(net) {
325            Ok((cell_ref, _offset)) if matches!(&*cell_ref.get(), Cell::Dff(_)) => self.past_net(net),
326            _ => self.curr_net(net),
327        }
328    }
329
330    fn value(&self, value: &Value) -> Result<SmtTritVec<SMT>, SMT::Error> {
331        Ok(self.tv_concat(value.iter().map(|net| self.net(net)).collect::<Result<Vec<_>, _>>()?.into_iter()))
332    }
333
334    fn invert_net(&self, control_net: ControlNet, tv_net: SmtTritVec<SMT>) -> Result<SmtTritVec<SMT>, SMT::Error> {
335        let active = if control_net.is_positive() { Const::ones(1) } else { Const::zero(1) };
336        Ok(SmtTritVec { y: self.engine.build_bvcomp(tv_net.y, self.engine.build_bitvec_lit(&active)), x: tv_net.x })
337    }
338
339    fn control_net(&self, control_net: ControlNet) -> Result<SmtTritVec<SMT>, SMT::Error> {
340        self.invert_net(control_net, self.net(control_net.net())?)
341    }
342
343    fn clock_net(&mut self, control_net: ControlNet) -> Result<SmtTritVec<SMT>, SMT::Error> {
344        self.tv_and(
345            self.invert_net(control_net, self.curr_net(control_net.net())?)?,
346            self.tv_not(self.invert_net(control_net, self.past_net(control_net.net())?)?),
347            1,
348        )
349    }
350
    /// Lowers `cell` to the trit-vector term describing the value it drives.
    ///
    /// `output` is the value driven by the cell. It supplies the result width for the bitwise and mux lowerings,
    /// the previous state of a flip-flop, and the free nets that stand in for an I/O buffer.
    ///
    /// # Panics
    ///
    /// Panics on cells that have no lowering here: multiplication, division and modulo cells, memories, target
    /// cells, instances, and `Input`/`Output`/`Name`/`Debug` cells.
    fn cell(&mut self, output: &Value, cell: &Cell) -> Result<SmtTritVec<SMT>, SMT::Error> {
        // Shared setup for the shift cells: extends `value` (sign- or zero-extended) and `amount` (zero-extended)
        // to a common width, and multiplies the value plane of the amount by `stride`. The width leaves room for
        // `amount.len()` bits plus the bits needed to represent `stride`.
        let prepare_shift = |value: &Value, amount: &Value, stride: u32, signed: bool| -> Result<_, SMT::Error> {
            let width = value.len().max(amount.len() + stride.max(1).ilog2() as usize + 1);
            let value_ext = if signed { value.sext(width) } else { value.zext(width) };
            let amount_ext = amount.zext(width);
            let (tv_value_ext, tv_amount_ext) = (self.value(&value_ext)?, self.value(&amount_ext)?);
            let bv_stride = self.engine.build_bitvec_lit(&Const::from_uint(stride as u128, width));
            let bv_amount_mul = self.engine.build_bvmul(tv_amount_ext.y.clone(), bv_stride);
            Ok((width, tv_value_ext, tv_amount_ext, bv_amount_mul))
        };

        let bv_cell = match &*cell {
            Cell::Buf(a) => self.value(a)?,
            Cell::Not(a) => self.tv_not(self.value(a)?),
            Cell::And(a, b) => self.tv_and(self.value(a)?, self.value(b)?, output.len())?,
            Cell::Or(a, b) => self.tv_or(self.value(a)?, self.value(b)?, output.len())?,
            Cell::Xor(a, b) => self.tv_xor(self.value(a)?, self.value(b)?),
            Cell::Mux(s, a, b) => self.tv_mux(self.net(*s)?, self.value(a)?, self.value(b)?, output.len())?,
            Cell::Adc(a, b, c) => {
                if a.is_empty() {
                    // With no addend bits the result is just the carry input.
                    self.net(*c)?
                } else {
                    let (tv_a, tv_b, tv_c) = (self.value(a)?, self.value(b)?, self.net(*c)?);
                    let (tv_a, tv_b) = (self.tv_bind(tv_a, a.len())?, self.tv_bind(tv_b, b.len())?);
                    // Undefinedness ripples upwards: result bit `index` is undefined if the carry input or any
                    // operand bit at or below `index` is undefined.
                    let mut bv_x = vec![];
                    let mut bv_carry_x = tv_c.x;
                    for index in 0..a.len() {
                        bv_carry_x = self.bv_bind(
                            self.engine.build_bvor(
                                bv_carry_x,
                                self.engine.build_bvor(
                                    self.engine.build_extract(index, index, tv_a.x.clone()),
                                    self.engine.build_extract(index, index, tv_b.x.clone()),
                                ),
                            ),
                            1,
                        )?;
                        bv_x.push(bv_carry_x.clone());
                    }
                    // The carry-out bit is undefined under the same condition as the topmost sum bit.
                    bv_x.push(bv_carry_x.clone());
                    SmtTritVec {
                        x: self.bv_concat(bv_x),
                        // All three operands are widened to `a.len() + 1` bits before being added.
                        y: self.engine.build_bvadd(
                            self.engine.build_bvadd(
                                self.engine.build_concat(self.bv_lit(Trit::Zero), tv_a.y),
                                self.engine.build_concat(self.bv_lit(Trit::Zero), tv_b.y),
                            ),
                            self.engine.build_concat(self.bv_lit(Const::zero(a.len())), tv_c.y),
                        ),
                    }
                }
            }
            Cell::Aig(a, b) => self.tv_and(self.control_net(*a)?, self.control_net(*b)?, 1)?,
            Cell::Eq(a, b) => {
                let (tv_a, tv_b) = (self.value(a)?, self.value(b)?);
                // Bits that are defined in both operands and differ.
                let bv_a_xor_b = self.engine.build_bvxor(tv_a.y.clone(), tv_b.y.clone());
                let bv_unequal = self.engine.build_bvand(
                    bv_a_xor_b,
                    self.engine.build_bvnot(self.engine.build_bvor(tv_a.x.clone(), tv_b.x.clone())),
                );
                let bool_any_unequal = self.engine.build_not(self.bv_is_zero(bv_unequal, a.len()));
                SmtTritVec {
                    // A defined mismatch makes the result defined; otherwise the result is undefined as soon as
                    // either operand has an undefined bit.
                    x: self.bool_to_bv(self.engine.build_bool_ite(
                        bool_any_unequal,
                        self.engine.build_bool_lit(false),
                        self.engine.build_not(self.engine.build_and(&[
                            self.bv_is_zero(tv_a.x.clone(), a.len()),
                            self.bv_is_zero(tv_b.x.clone(), b.len()),
                        ])),
                    )),
                    y: self.engine.build_bvcomp(tv_a.y, tv_b.y),
                }
            }
            Cell::ULt(a, b) => {
                let (tv_a, tv_b) = (self.value(a)?, self.value(b)?);
                SmtTritVec {
                    // Undefined unless both operands are fully defined.
                    x: self.bool_to_bv(self.engine.build_not(
                        self.engine.build_and(&[self.bv_is_zero(tv_a.x, a.len()), self.bv_is_zero(tv_b.x, b.len())]),
                    )),
                    y: self.bool_to_bv(self.engine.build_bvult(tv_a.y, tv_b.y)),
                }
            }
            Cell::SLt(a, b) => {
                let (tv_a, tv_b) = (self.value(a)?, self.value(b)?);
                SmtTritVec {
                    // Undefined unless both operands are fully defined.
                    x: self.bool_to_bv(self.engine.build_not(
                        self.engine.build_and(&[self.bv_is_zero(tv_a.x, a.len()), self.bv_is_zero(tv_b.x, b.len())]),
                    )),
                    y: self.bool_to_bv(self.engine.build_bvslt(tv_a.y, tv_b.y)),
                }
            }
            // For every shift below: if any bit of the amount is undefined, the whole result is undefined
            // (`x` is all ones); otherwise the undefined mask is shifted along with the value plane. The result
            // is truncated back to `value.len()` bits.
            Cell::Shl(value, amount, stride) => {
                let (width, tv_value_ext, tv_amount_ext, bv_amount_mul) = prepare_shift(value, amount, *stride, false)?;
                let tv_result = SmtTritVec {
                    x: self.engine.build_bitvec_ite(
                        self.bv_is_zero(tv_amount_ext.x, width),
                        self.engine.build_bvshl(tv_value_ext.x, bv_amount_mul.clone()),
                        self.engine.build_bitvec_lit(&Const::ones(width)),
                    ),
                    y: self.engine.build_bvshl(tv_value_ext.y, bv_amount_mul),
                };
                self.tv_extract(value.len() - 1, 0, tv_result)
            }
            Cell::UShr(value, amount, stride) => {
                let (width, tv_value_ext, tv_amount_ext, bv_amount_mul) = prepare_shift(value, amount, *stride, false)?;
                let tv_result = SmtTritVec {
                    x: self.engine.build_bitvec_ite(
                        self.bv_is_zero(tv_amount_ext.x, width),
                        self.engine.build_bvlshr(tv_value_ext.x, bv_amount_mul.clone()),
                        self.engine.build_bitvec_lit(&Const::ones(width)),
                    ),
                    y: self.engine.build_bvlshr(tv_value_ext.y, bv_amount_mul),
                };
                self.tv_extract(value.len() - 1, 0, tv_result)
            }
            Cell::SShr(value, amount, stride) => {
                let (width, tv_value_ext, tv_amount_ext, bv_amount_mul) = prepare_shift(value, amount, *stride, true)?;
                let tv_result = SmtTritVec {
                    x: self.engine.build_bitvec_ite(
                        self.bv_is_zero(tv_amount_ext.x, width),
                        self.engine.build_bvashr(tv_value_ext.x, bv_amount_mul.clone()),
                        self.engine.build_bitvec_lit(&Const::ones(width)),
                    ),
                    y: self.engine.build_bvashr(tv_value_ext.y, bv_amount_mul),
                };
                self.tv_extract(value.len() - 1, 0, tv_result)
            }
            Cell::XShr(value, amount, stride) => {
                // An undefined bit is appended above `value` before the signed extension.
                // NOTE(review): presumably `sext` replicates that undefined bit so the bits shifted in are
                // undefined -- confirm against `Value::sext`.
                let (width, tv_value_ext, tv_amount_ext, bv_amount_mul) =
                    prepare_shift(&value.concat(Value::undef(1)), amount, *stride, true)?;
                let tv_result = SmtTritVec {
                    x: self.engine.build_bitvec_ite(
                        self.bv_is_zero(tv_amount_ext.x, width),
                        self.engine.build_bvashr(tv_value_ext.x, bv_amount_mul.clone()),
                        self.engine.build_bitvec_lit(&Const::ones(width)),
                    ),
                    y: self.engine.build_bvlshr(tv_value_ext.y, bv_amount_mul),
                };
                self.tv_extract(value.len() - 1, 0, tv_result)
            }
            Cell::Mul(..) => unimplemented!("lowering of mul to SMT-LIB is not implemented"),
            Cell::UDiv(..) => unimplemented!("lowering of udiv to SMT-LIB is not implemented"),
            Cell::UMod(..) => unimplemented!("lowering of umod to SMT-LIB is not implemented"),
            Cell::SDivTrunc(..) => unimplemented!("lowering of sdiv_trunc to SMT-LIB is not implemented"),
            Cell::SDivFloor(..) => unimplemented!("lowering of sdiv_floor to SMT-LIB is not implemented"),
            Cell::SModTrunc(..) => unimplemented!("lowering of smod_trunc to SMT-LIB is not implemented"),
            Cell::SModFloor(..) => unimplemented!("lowering of smod_floor to SMT-LIB is not implemented"),
            Cell::Match(MatchCell { value, enable, patterns }) => {
                // One match bit per entry of `patterns`: the OR over its alternates of the AND over all
                // non-undefined mask positions of "value bit equals mask bit".
                let mut tv_matches = vec![];
                for alternates in patterns {
                    let mut tv_sum = self.tv_lit(Trit::Zero);
                    for pattern in alternates {
                        let mut tv_prd = self.tv_lit(Trit::One);
                        for (index, mask) in pattern.iter().enumerate().filter(|(_index, mask)| *mask != Trit::Undef) {
                            let tv_eq = self.tv_not(self.tv_xor(self.net(value[index])?, self.net(mask.into())?));
                            tv_prd = self.tv_and(tv_prd, tv_eq, 1)?;
                        }
                        tv_sum = self.tv_or(tv_sum, tv_prd, 1)?;
                    }
                    tv_matches.push(tv_sum);
                }
                let tv_matches = self.tv_bind(self.tv_concat(tv_matches), patterns.len())?;
                let tv_enable = self.net(*enable)?;
                let tv_all_cold = self.tv_lit(&Const::zero(patterns.len()));
                // Priority encoding into a one-hot result: the mux chain is built from the last pattern
                // backwards, so the mux for index 0 ends up outermost and the lowest matching index wins.
                let mut tv_result = tv_all_cold.clone();
                for index in (0..patterns.len()).rev() {
                    tv_result = self.tv_mux(
                        self.tv_extract(index, index, tv_matches.clone()),
                        self.tv_lit(&Const::one_hot(patterns.len(), index)),
                        tv_result,
                        patterns.len(),
                    )?;
                }
                // When disabled, all outputs are zero.
                self.tv_mux(tv_enable.clone(), tv_result, tv_all_cold, patterns.len())?
            }
            // When enabled, the output is `value` with `update` spliced in starting at `offset`; otherwise it
            // is `value` unchanged.
            Cell::Assign(AssignCell { value, enable, update, offset }) => self.tv_mux(
                self.net(*enable)?,
                self.value(&{
                    let mut nets = Vec::from_iter(value.iter());
                    nets[*offset..(*offset + update.len())].copy_from_slice(&update[..]);
                    Value::from_iter(nets)
                })?,
                self.value(value)?,
                output.len(),
            )?,
            Cell::Dff(flip_flop) => {
                let mut data = self.value(&flip_flop.data)?;
                let clear = self.control_net(flip_flop.clear)?;
                let reset = self.control_net(flip_flop.reset)?;
                let enable = self.control_net(flip_flop.enable)?;
                // The mux applied last takes priority, so the order of the two muxes implements
                // `reset_over_enable`.
                if flip_flop.reset_over_enable {
                    data = self.tv_mux(enable, data, self.past_value(&output)?, output.len())?;
                    data = self.tv_mux(reset, self.tv_lit(&flip_flop.reset_value), data, output.len())?;
                } else {
                    data = self.tv_mux(reset, self.tv_lit(&flip_flop.reset_value), data, output.len())?;
                    data = self.tv_mux(enable, data, self.past_value(&output)?, output.len())?;
                }
                // Without an active clock edge the flip-flop keeps its previous output.
                let active_edge = self.clock_net(flip_flop.clock)?;
                let value = self.tv_mux(active_edge, data, self.past_value(&output)?, output.len())?;
                // The clear is applied outside the clock-edge mux, so it does not depend on the clock.
                if flip_flop.has_clear() {
                    self.tv_mux(clear, self.tv_lit(&flip_flop.clear_value), value, output.len())?
                } else {
                    value
                }
            }
            Cell::Memory(_memory) => unimplemented!("memories are not lowered to SMT-LIB yet"),
            Cell::IoBuf(_io_buffer) => self.value(output)?, // i/en/o treated as POs/PIs
            Cell::Target(_target_cell) => unimplemented!("target cells cannot be lowered to SMT-LIB yet"),
            Cell::Other(_) => unreachable!("instances cannot be lowered to SMT-LIB"),
            Cell::Input(..) | Cell::Output(..) | Cell::Name(..) | Cell::Debug(..) => unreachable!(),
        };

        Ok(bv_cell)
    }
565
566    pub fn add_cell(&mut self, output: &Value, cell: &Cell) -> Result<(), SMT::Error> {
567        // Declare the nets used by the cell so that it is present in the counterexample even if unused.
568        if let Cell::Input(..) = cell {
569            self.curr_value(output)?;
570            return Ok(());
571        }
572        let tv_cell = self.cell(output, cell)?;
573        self.engine.assert(self.tv_eq(self.curr_value(output)?, tv_cell))
574    }
575
576    pub fn replace_cell(&mut self, output: &Value, old_cell: &Cell, new_cell: &Cell) -> Result<(), SMT::Error> {
577        self.add_cell(output, old_cell)?;
578        let tv_new_cell = self.cell(output, new_cell)?;
579        self.eqs.borrow_mut().push(self.tv_refines(self.curr_value(output)?, tv_new_cell));
580        Ok(())
581    }
582
583    pub fn replace_net(&mut self, from_net: Net, to_net: Net) -> Result<(), SMT::Error> {
584        self.eqs.borrow_mut().push(self.tv_refines(self.curr_net(from_net)?, self.net(to_net)?));
585        Ok(())
586    }
587
588    pub fn replace_void_net(&mut self, from_net: Net, to_net: Net) -> Result<(), SMT::Error> {
589        self.engine.assert(self.tv_eq(self.curr_net(from_net)?, self.net(to_net)?))
590    }
591
592    pub fn replace_dff_net(&mut self, from_net: Net, to_net: Net) -> Result<(), SMT::Error> {
593        // Essentially a single induction step.
594        self.engine.assert(self.tv_eq(self.past_net(from_net)?, self.past_net(to_net)?))?;
595        self.eqs.borrow_mut().push(self.tv_refines(self.curr_net(from_net)?, self.curr_net(to_net)?));
596        Ok(())
597    }
598
599    pub fn check(&mut self) -> Result<Option<SmtExample>, SMT::Error> {
600        if self.eqs.borrow().is_empty() {
601            return Ok(None);
602        }
603        let not_and_eqs = self.engine.build_not(self.engine.build_and(&self.eqs.borrow()[..]));
604        self.engine.assert(not_and_eqs)?;
605        match self.engine.check()? {
606            SmtResponse::Unknown => panic!("SMT solver returned unknown"),
607            SmtResponse::Unsat => Ok(None),
608            SmtResponse::Sat => {
609                let get_trit = |tv_net: &SmtTritVec<SMT>| -> Result<Trit, SMT::Error> {
610                    if self.engine.get_bitvec(&tv_net.x)?[0] == Trit::One {
611                        Ok(Trit::Undef)
612                    } else {
613                        Ok(self.engine.get_bitvec(&tv_net.y)?[0])
614                    }
615                };
616                let (mut curr, mut past) = (BTreeMap::new(), BTreeMap::new());
617                for (net, tv_net) in self.curr.borrow().iter() {
618                    curr.insert(*net, get_trit(tv_net)?);
619                }
620                for (net, tv_net) in self.past.borrow().iter() {
621                    past.insert(*net, get_trit(tv_net)?);
622                }
623                Ok(Some(SmtExample { curr, past }))
624            }
625        }
626    }
627}
628
/// A counterexample produced by `SmtBuilder::check`: one trit per net, for the current and the previous cycle.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmtExample {
    /// Current-cycle value of each recorded net.
    curr: BTreeMap<Net, Trit>,
    /// Previous-cycle value of each recorded net.
    past: BTreeMap<Net, Trit>,
}
634
635impl SmtExample {
636    pub fn get_net(&self, net: Net) -> Option<Trit> {
637        self.curr.get(&net).cloned()
638    }
639
640    pub fn get_value<'a>(&self, value: impl Into<Cow<'a, Value>>) -> Option<Const> {
641        let mut result = Const::new();
642        for net in &*value.into() {
643            result.push(self.get_net(net)?);
644        }
645        Some(result)
646    }
647
648    pub fn get_past_net(&self, net: Net) -> Option<Trit> {
649        self.past.get(&net).cloned()
650    }
651
652    pub fn get_past_value<'a>(&self, value: impl Into<Cow<'a, Value>>) -> Option<Const> {
653        let mut result = Const::new();
654        for net in &*value.into() {
655            result.push(self.get_past_net(net)?);
656        }
657        Some(result)
658    }
659}