/*++ Copyright (c) 2017 Microsoft Corporation, Simon Cruanes Module Name: recfun_decl_plugin.h Abstract: Declaration and definition of (potentially recursive) functions Author: Simon Cruanes 2017-11 Revision History: --*/ #pragma once #include "ast/ast.h" #include "ast/ast_pp.h" #include "util/obj_hashtable.h" namespace recfun { class case_def; // cases; ast_manager & m; symbol m_name; // def_map; typedef obj_map case_def_map; mutable scoped_ptr m_util; def_map m_defs; // function->def case_def_map m_case_defs; // case_pred->def ast_manager & m() { return *m_manager; } void compute_scores(expr* e, obj_map& scores); public: plugin(); ~plugin() override; void finalize() override; util & u() const; // build or return util bool is_fully_interp(sort * s) const override { return false; } // might depend on unin sorts decl_plugin * mk_fresh() override { return alloc(plugin); } sort * mk_sort(decl_kind k, unsigned num_parameters, parameter const * parameters) override { UNREACHABLE(); return nullptr; } func_decl * mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, unsigned arity, sort * const * domain, sort * range) override; promise_def mk_def(symbol const& name, unsigned n, sort *const * params, sort * range, bool is_generated = false); promise_def ensure_def(symbol const& name, unsigned n, sort *const * params, sort * range, bool is_generated = false); void set_definition(replace& r, promise_def & d, unsigned n_vars, var * const * vars, expr * rhs); def* mk_def(replace& subst, symbol const& name, unsigned n, sort ** params, sort * range, unsigned n_vars, var ** vars, expr * rhs); void erase_def(func_decl* f); bool has_def(func_decl* f) const { return m_defs.contains(f); } bool has_defs() const; def const& get_def(func_decl* f) const { return *(m_defs[f]); } promise_def get_promise_def(func_decl* f) const { return promise_def(&u(), m_defs[f]); } def& get_def(func_decl* f) { return *(m_defs[f]); } bool has_case_def(func_decl* f) const { 
// NOTE(review): this header appears to have been damaged in transit.
//   - It is collapsed onto a handful of very long lines.
//   - Text that sat between angle brackets seems to have been stripped.
//     `typedef obj_map case_def_map;`, `mutable scoped_ptr m_util;`,
//     `obj_map& scores` and `ptr_vector const & args` have lost their
//     template arguments.
//   - The openings of `class plugin` and the enclosing `namespace decl` are
//     gone.
// As laid out, `#pragma once` on the first line consumes the rest of that
// line. Each `//` comment also runs to the end of its line and swallows the
// code after it; for example, all of `class util` on the next line follows
// `// Various utils ...`. The file cannot compile in this state. Restore it
// from version control rather than patching it by hand.
//
// recfun::decl::plugin (members above, finished at the start of the next line):
//  - m_defs maps a defined function to its def ("function->def").
//  - m_case_defs maps a case predicate to its case_def ("case_pred->def").
//  - u() builds or returns the util held in the `mutable` m_util.
//  - mk_def / ensure_def return a promise_def.
//  - set_definition takes that promise_def together with the variables and
//    right-hand side (presumably to complete it; TODO confirm in the .cpp).
//  - get_def and get_promise_def index m_defs directly, with no has_def check.
//  - get_case_def only SASSERTs has_case_def, which is checked in debug
//    builds only.
//  - Because of the two points above, callers should test has_def /
//    has_case_def first.
//  - get_rec_funs returns every key of m_defs as a func_decl_ref_vector.
//  - mk_sort is UNREACHABLE(): this plugin is not expected to create sorts.
//  - is_fully_interp always returns false ("might depend on unin sorts").
//
// recfun::util (next line) wraps an ast_manager, the plugin's family_id and a
// pointer to the plugin:
//  - is_case_pred / is_defined / is_num_rounds test for an application of
//    this family with kind OP_FUN_CASE_PRED / OP_FUN_DEFINED / OP_NUM_ROUNDS.
//  - is_generated(f) is true when f is a defined function whose first
//    parameter is the integer 1. This is presumably the is_generated flag
//    passed to mk_def (TODO confirm).
//  - owns_app(e) compares e's family id with this util's.
//  - get_def / get_case_def SASSERT their precondition and forward to the
//    plugin.
//  - mk_fun_defined builds `d.get_decl()` applied to the given arguments.
//    The vector overloads forward to the (count, pointer) overload.
//  - decl::plugin is declared a friend, presumably so it can call the
//    private set_definition / compute_is_immediate.
return m_case_defs.contains(f); } case_def& get_case_def(func_decl* f) { SASSERT(has_case_def(f)); return *(m_case_defs[f]); } func_decl_ref_vector get_rec_funs() { func_decl_ref_vector result(m()); for (auto& kv : m_defs) result.push_back(kv.m_key); return result; } expr_ref redirect_ite(replace& subst, unsigned n, var * const* vars, expr * e); }; } // Various utils for recursive functions class util { friend class decl::plugin; ast_manager & m_manager; family_id m_fid; decl::plugin * m_plugin; bool compute_is_immediate(expr * rhs); void set_definition(replace& r, promise_def & d, unsigned n_vars, var * const * vars, expr * rhs); public: util(ast_manager &m); ~util(); ast_manager & m() { return m_manager; } family_id get_family_id() const { return m_fid; } decl::plugin& get_plugin() { return *m_plugin; } bool is_case_pred(expr * e) const { return is_app_of(e, m_fid, OP_FUN_CASE_PRED); } bool is_defined(expr * e) const { return is_app_of(e, m_fid, OP_FUN_DEFINED); } bool is_defined(func_decl* f) const { return is_decl_of(f, m_fid, OP_FUN_DEFINED); } bool is_generated(func_decl* f) const { return is_defined(f) && f->get_parameter(0).get_int() == 1; } bool is_num_rounds(expr * e) const { return is_app_of(e, m_fid, OP_NUM_ROUNDS); } bool owns_app(app * e) const { return e->get_family_id() == m_fid; } //has_defs(); } //has_def(f); } def& get_def(func_decl* f) { SASSERT(has_def(f)); return m_plugin->get_def(f); } case_def& get_case_def(expr* e) { SASSERT(is_case_pred(e)); return m_plugin->get_case_def(to_app(e)->get_decl()); } app* mk_fun_defined(def const & d, unsigned n_args, expr * const * args) { return m().mk_app(d.get_decl(), n_args, args); } app* mk_fun_defined(def const & d, ptr_vector const & args) { return mk_fun_defined(d, args.size(), args.data()); } app* mk_fun_defined(def const & d, expr_ref_vector const & args) { return mk_fun_defined(d, args.size(), args.data()); } func_decl_ref_vector get_rec_funs() { return m_plugin->get_rec_funs(); } app_ref 
// The next line closes recfun::util (mk_num_rounds_pred). It then declares
// the work items used when unfolding recursive definitions.
// NOTE(review): in this collapsed layout, everything after
// `// one case-expansion of ...` on that line is commented out. The line
// after it also begins mid-declaration.
//
// case_expansion: one case-expansion of `f(t1...tn)`.
//   - m_lhs is the term to expand.
//   - m_def is its definition and m_args its arguments.
//
// body_expansion: one body-expansion of `f(t1...tn)` through a case
// predicate `C_f_i(t1...tn)`.
//   - The (util, app) constructor looks up the case_def with
//     u.get_case_def(n), which SASSERTs that n is a case predicate. It then
//     copies n's arguments.
//   - NOTE(review): the move constructor copies m_pred and only moves m_args.
//     Moving m_pred too would save a reference-count round trip, if app_ref
//     supports moves (TODO confirm).
//
// propagation_item: holds exactly one of a guard, a core, a body_expansion or
// a case_expansion, depending on which constructor ran. The other pointers
// keep their nullptr initializers.
//   - is_guard/is_core/is_case/is_body report which pointer is non-null.
//   - The accessors SASSERT that kind before returning it.
//   - The destructor deallocs m_case, m_body and m_core. So the item owns the
//     expansion passed to it and its private alloc'd copy of the core.
//   - m_guard is not freed.
//   - NOTE(review): there is no user-declared copy constructor or copy
//     assignment, so an implicit copy would share those pointers and dealloc
//     them twice. Consider `= delete`-ing both (rule of five), after checking
//     that no caller copies propagation_item by value.
//   - NOTE(review): the single-argument constructors are not `explicit`. They
//     allow implicit conversions from expr*, expr_ref_vector,
//     body_expansion* and case_expansion*.
mk_num_rounds_pred(unsigned d); }; // one case-expansion of `f(t1...tn)` struct case_expansion { app_ref m_lhs; // the term to expand recfun::def * m_def; expr_ref_vector m_args; case_expansion(recfun::util& u, app * n); case_expansion(case_expansion const & from); case_expansion(case_expansion && from); std::ostream& display(std::ostream& out) const; }; inline std::ostream& operator<<(std::ostream& out, case_expansion const & e) { return e.display(out); } // one body-expansion of `f(t1...tn)` using a `C_f_i(t1...tn)` struct body_expansion { app_ref m_pred; recfun::case_def const * m_cdef; expr_ref_vector m_args; body_expansion(recfun::util& u, app * n) : m_pred(n, u.m()), m_cdef(nullptr), m_args(u.m()) { m_cdef = &u.get_case_def(n); m_args.append(n->get_num_args(), n->get_args()); } body_expansion(app_ref & pred, recfun::case_def const & d, expr_ref_vector & args) : m_pred(pred), m_cdef(&d), m_args(args) {} body_expansion(body_expansion const & from): m_pred(from.m_pred), m_cdef(from.m_cdef), m_args(from.m_args) {} body_expansion(body_expansion && from) : m_pred(from.m_pred), m_cdef(from.m_cdef), m_args(std::move(from.m_args)) {} std::ostream& display(std::ostream& out) const; }; inline std::ostream& operator<<(std::ostream& out, body_expansion const& e) { return e.display(out); } struct propagation_item { case_expansion* m_case { nullptr }; body_expansion* m_body { nullptr }; expr_ref_vector* m_core { nullptr }; expr* m_guard { nullptr }; ~propagation_item() { dealloc(m_case); dealloc(m_body); dealloc(m_core); } propagation_item(expr* guard): m_guard(guard) {} propagation_item(expr_ref_vector const& core): m_core(alloc(expr_ref_vector, core)) { } propagation_item(body_expansion* b): m_body(b) {} propagation_item(case_expansion* c): m_case(c) {} bool is_guard() const { return m_guard != nullptr; } bool is_core() const { return m_core != nullptr; } bool is_case() const { return m_case != nullptr; } bool is_body() const { return m_body != nullptr; } expr_ref_vector 
const& core() const { SASSERT(is_core()); return *m_core; } body_expansion & body() const { SASSERT(is_body()); return *m_body; } case_expansion & case_ex() const { SASSERT(is_case()); return *m_case; } expr* guard() const { SASSERT(is_guard()); return m_guard; } }; }