LTNjax: JAX implementation of Logic Tensor Networks

Logic Tensor Networks (LTN)

LTNjax is a neurosymbolic framework that lets users express knowledge as logical formulas and use these formulas as training objectives for neural networks. LTN uses a differentiable first-order logic language, called Real Logic, to combine machine learning and logical reasoning.

This repository contains an implementation of LTN in JAX. The project is based on logictensornetworks/logictensornetworks (TensorFlow) and tommasocarraro/LTNtorch (PyTorch).

Basics

The framework consists of the following components:

  • Constants, which are tensors of any rank \(n \geq 0\), i.e. an element of \(\mathbb{R}\) (for \(n = 0\)) or of \(\mathbb{R}^{(d_1 \times \dotsc \times d_n)}\), where \(d_1, \dotsc, d_n\) are called the feature dimensions. A Constant becomes trainable when its value is an nnx.Param.
  • Variables, which consist of \(m\) individuals that are tensors of any rank. Each individual is an element of \(\mathbb{R}^{(d_1 \times \dotsc \times d_n)}\) (the feature dimensions), so a Variable is an element of \(\mathbb{R}^{(m \times d_1 \times \dotsc \times d_n)}\). A Variable becomes trainable when its value is an nnx.Param.
  • Predicates, which take tuples of Constant and Variable objects as arguments, are applied to every combination of the variables' individuals, and return truth values in \([0, 1]\). Predicates are built from trainable nnx.Module objects, which can be neural networks or lambda expressions.
  • Functions, which take tuples of Constant and Variable objects as arguments, are applied to every combination of the variables' individuals, and return tensors in \(\mathbb{R}^{(d_1 \times \dotsc \times d_n)}\). Like Predicates, Functions are built from trainable nnx.Module objects, which can be neural networks or lambda expressions.
  • Connectives, which combine an arbitrary number of Expressions into a single Expression, e.g. \(\land, \lor, \neg, \Rightarrow, \leftrightarrow\).
  • (Masked) quantifiers \(\forall\) and \(\exists\), which quantify over a set of Variables, optionally restricted to the individuals that satisfy a given mask, and aggregate the truth values of a given Expression over those individuals.
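To make the semantics of these components concrete, here is a minimal sketch in plain JAX. It is an illustration, not the ltnjax API: all function names are hypothetical, and the operator choices (product t-norm connectives, Reichenbach implication, pMean for \(\exists\) and pMeanError for \(\forall\) with \(p = 2\)) are assumptions; ltnjax/fuzzy_ops.py offers other operators. A binary predicate is grounded on every pair of individuals of two variables, and the resulting truth values are aggregated under a mask.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers illustrating the semantics above; NOT the ltnjax API.


def ground_predicate(w, x, y):
    """Toy binary predicate P(x, y) = sigmoid(w . [x; y]).

    x holds m individuals and y holds k individuals, each of shape (d,).
    P is evaluated on every pair (x_i, y_j), giving an (m, k) tensor of
    truth values in [0, 1].
    """
    m, k, d = x.shape[0], y.shape[0], x.shape[1]
    xs = jnp.broadcast_to(x[:, None, :], (m, k, d))
    ys = jnp.broadcast_to(y[None, :, :], (m, k, d))
    return jax.nn.sigmoid(jnp.concatenate([xs, ys], axis=-1) @ w)


# Connectives from the product t-norm family (one common choice in LTN).
def not_(a):
    return 1.0 - a


def and_(a, b):
    return a * b


def or_(a, b):
    return a + b - a * b


def implies(a, b):  # Reichenbach implication
    return 1.0 - a + a * b


def _masked_mean(v, mask):
    """Mean of v over the entries where mask is True."""
    mask = mask.astype(v.dtype)
    return jnp.sum(v * mask) / jnp.sum(mask)


def forall(truth, mask, p=2.0):
    """pMeanError aggregator: 1 - (mean((1 - t)^p))^(1/p) over masked entries."""
    return 1.0 - _masked_mean((1.0 - truth) ** p, mask) ** (1.0 / p)


def exists(truth, mask, p=2.0):
    """pMean aggregator: (mean(t^p))^(1/p) over masked entries."""
    return _masked_mean(truth ** p, mask) ** (1.0 / p)


x = jax.random.normal(jax.random.PRNGKey(0), (3, 2))  # variable with m = 3 individuals
y = jax.random.normal(jax.random.PRNGKey(1), (4, 2))  # variable with k = 4 individuals
w = jnp.array([0.5, -1.0, 1.0, 0.3])                  # 2 * d = 4 weights

truth = ground_predicate(w, x, y)         # shape (3, 4), one value per (x_i, y_j)
mask = jnp.ones_like(truth, dtype=bool)   # no masking: quantify over all pairs
print(forall(truth, mask), exists(truth, mask))
```

With \(p \geq 1\), this pMeanError is never larger than the arithmetic mean of the truth values and this pMean is never smaller, so the \(\forall\) value stays below the \(\exists\) value. Setting a mask entry to False simply removes that individual (or pair) from the aggregation.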

Repository structure

Getting Started

Tutorials

tutorials/ contains tutorials for getting started with coding in LTN. We suggest completing them in order. The tutorials cover the following topics:

  1. Grounding in LTN (part 1): Real Logic, constants, predicates, functions, variables;
  2. Grounding in LTN (part 2): connectives and quantifiers (+ complement: choosing appropriate operators for learning);
  3. Learning in LTN: using satisfiability of LTN formulas as a training objective.

The tutorials are implemented as Jupyter notebooks.

Training

To train neural networks together with Logic Tensor Networks (LTN), the user can either place a neural network inside the LTN as a Function or Predicate, or put trainable values into Variables or Constants.

In both cases, define an nnx.Module: create the sub-modules (nnx.Module) and trainable parameters (nnx.Param) in its __init__ method, and build the LTN formula in its __call__ method.
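The sketch below mimics this pattern in plain JAX instead of nnx, so it shows neither the ltnjax nor the Flax API. A parameter dictionary stands in for the nnx.Param attributes created in __init__, and satisfiability stands in for the LTN formula evaluated in __call__. The toy knowledge base is hypothetical: a predicate P should hold for every positive example and fail for every negative one. The loss is one minus the satisfiability of that knowledge base, minimized with plain gradient descent.

```python
import jax
import jax.numpy as jnp

# Plain-JAX stand-in for the nnx.Module pattern; all names are hypothetical.


def predicate(params, x):
    """Toy unary predicate P(x) = sigmoid(x . w + b), one truth value per row."""
    return jax.nn.sigmoid(x @ params["w"] + params["b"])


def forall(truth, p=2.0):
    """pMeanError aggregator over all individuals."""
    return 1.0 - jnp.mean((1.0 - truth) ** p) ** (1.0 / p)


def satisfiability(params, pos, neg):
    # Knowledge base: (forall x in pos: P(x)) and (forall x in neg: not P(x)).
    sat_pos = forall(predicate(params, pos))
    sat_neg = forall(1.0 - predicate(params, neg))
    return sat_pos * sat_neg  # product t-norm conjunction of the two axioms


def loss(params, pos, neg):
    return 1.0 - satisfiability(params, pos, neg)


pos = jnp.array([[1.0, 2.0], [2.0, 1.5], [1.5, 1.0]])
neg = jnp.array([[-1.0, -2.0], [-2.0, -1.0], [-1.5, -0.5]])
params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}

grad_fn = jax.jit(jax.grad(loss))
sat_before = satisfiability(params, pos, neg)
for _ in range(100):
    grads = grad_fn(params, pos, neg)
    # Gradient descent step applied to every leaf of the parameter pytree.
    params = jax.tree_util.tree_map(lambda v, g: v - 0.5 * g, params, grads)
sat_after = satisfiability(params, pos, neg)
print(sat_before, sat_after)
```

At initialization every truth value is 0.5, so each axiom has satisfiability 0.5 and their conjunction 0.25; training drives the satisfiability towards 1. In an nnx-based setup, the optimizer update would replace the manual tree_map step.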

For the first option, see train_mlp.py and for the second one see trainable_expr.py.

Examples

  • The above example can be found in connectives.py.
  • A basic example to train an MLP can be found in train_mlp.py.
  • Every Constant and Variable can be trainable: use an nnx.Param as its value to make it trainable, or any ArrayLike value otherwise. See, for example, trainable_expr.py.
  • During training, the user should watch out for gradient issues; see gradient_issues.py.
  • An example for masked quantifiers can be found in masked_quantifiers.py.
  • weighted_ops.py illustrates how to use weighted aggregation operations.
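One well-known gradient issue can be reproduced in a few lines. The sketch below is an illustration, not the contents of gradient_issues.py: it computes the gradient of a pMean aggregator, a common choice for \(\exists\), at truth values that are all exactly 0, and the result is not finite. Projecting the truth values into \([\epsilon, 1]\) before aggregating, similar in spirit to the stable option of the aggregators in the TensorFlow and PyTorch LTN implementations, keeps the gradient finite. The function names and the value of eps are assumptions.

```python
import jax
import jax.numpy as jnp


def p_mean(truth, p=2.0):
    """pMean aggregator (mean(t^p))^(1/p), e.g. used for 'exists'."""
    return jnp.mean(truth ** p) ** (1.0 / p)


def stable_p_mean(truth, p=2.0, eps=1e-4):
    """pMean after projecting truth values from [0, 1] into [eps, 1]."""
    return p_mean((1.0 - eps) * truth + eps, p)


truth = jnp.zeros(3)
# d/dm m^(1/p) is infinite at m = 0 and gets multiplied by d(t^p)/dt = 0, giving NaN.
g_unstable = jax.grad(p_mean)(truth)
g_stable = jax.grad(stable_p_mean)(truth)
print(g_unstable)  # non-finite entries
print(g_stable)    # finite entries
```

The projection slightly changes the value of the aggregator (here it returns eps instead of 0 for all-false inputs), which is the price paid for well-defined gradients.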

Improvements to previous LTN frameworks

  • A broader collection of fuzzy operations, see ltnjax/fuzzy_ops.py.
  • This framework allows ltnjax.Connective to take arbitrarily many arguments, see examples/connectives.py.
  • This framework allows weighted aggregators, see examples/weighted_ops.py.
  • The framework now supports empty variables, i.e. \(x \in \emptyset\). This matters because filtering can leave a variable with no individuals at all.
  • ltnjax.Constant and ltnjax.Variable objects can now be passed directly to ltnjax.Connective or ltnjax.Quantifier. In older versions, a Function (or Predicate) had to be applied first.
  • ltnjax.Connective and ltnjax.Quantifier can handle tensors with feature dimensions, that is, not only truth values in \([0,1]\) (as needed for logical formulas) but also tensors in \(\mathbb{R}^{(d_1 \times \dotsc \times d_n)}\). This can be useful, for example, to sum up predicate outputs and use them as a regularization term.
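Some of these features can be sketched in plain JAX to show the intended semantics. The helpers below are hypothetical and do not reproduce the ltnjax API or ltnjax/fuzzy_ops.py. and_n folds the product t-norm over any number of arguments, and weighted_forall is one possible weighted pMeanError that aggregates over the individuals (axis 0) while keeping any feature dimensions. For an empty variable it returns 1 (vacuous truth, the usual convention in classical logic); whether ltnjax uses the same convention is an assumption here.

```python
from functools import reduce

import jax.numpy as jnp

# Hypothetical helpers illustrating the features above; NOT the ltnjax API.


def and_n(*truths):
    """n-ary product t-norm: conjunction of arbitrarily many truth tensors."""
    return reduce(lambda a, b: a * b, truths)


def weighted_forall(truth, weights, p=2.0):
    """Weighted pMeanError over the individuals (axis 0):
    1 - (sum_i w_i (1 - t_i)^p / sum_i w_i)^(1/p).

    Trailing feature dimensions are kept. For zero individuals this
    sketch returns 1 (vacuous truth); that convention is an assumption.
    """
    if truth.shape[0] == 0:
        return jnp.ones(truth.shape[1:], dtype=truth.dtype)
    # Reshape weights from (m,) to (m, 1, ..., 1) so they broadcast over features.
    w = weights.reshape((-1,) + (1,) * (truth.ndim - 1))
    err = jnp.sum(w * (1.0 - truth) ** p, axis=0) / jnp.sum(w)
    return 1.0 - err ** (1.0 / p)


t = jnp.array([1.0, 0.0])
print(and_n(jnp.array(1.0), jnp.array(0.5), jnp.array(0.5)))
print(weighted_forall(t, jnp.array([1.0, 1.0])))  # both individuals count
print(weighted_forall(t, jnp.array([1.0, 0.0])))  # the false individual is ignored
print(weighted_forall(jnp.zeros((0,)), jnp.zeros((0,))))  # empty variable
```

A weight of 0 removes an individual from the aggregation, so weights generalize boolean masks to soft importance values.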