LTNjax: JAX implementation of Logic Tensor Networks
Logic Tensor Networks (LTN)
LTNjax is a neurosymbolic framework that lets you express knowledge as logical formulas and use them as a training objective for neural networks. LTN uses a differentiable first-order logic language, called Real Logic, to integrate machine learning and logical reasoning.
This repository contains an implementation of LTN in JAX. It is based on the projects logictensornetworks/logictensornetworks (TensorFlow) and tommasocarraro/LTNtorch (PyTorch).
Basics
The framework consists of the following components:
- Constants: tensors of any rank \(n \geq 0\), i.e. elements of \(\mathbb{R}\) (for \(n = 0\)) or of \(\mathbb{R}^{(d_1 \times \dotsc \times d_n)}\), whose axes we call feature dimensions. Constants can be made trainable by passing nnx.Param values.
- Variables: collections of \(m\) individuals, each a tensor of any rank, i.e. an element of \(\mathbb{R}^{(d_1 \times \dotsc \times d_n)}\) with feature dimensions \(d_1, \dotsc, d_n\). A Variable is therefore an element of \(\mathbb{R}^{(m \times d_1 \times \dotsc \times d_n)}\). Variables can be made trainable by passing nnx.Param values.
- Predicates: take tuples of Constant and Variable objects as arguments, are applied to each combination of the variables' individuals, and return a truth value in \([0, 1]\) for each. A predicate is backed by a trainable nnx.Module, such as a neural network, or by a lambda expression.
- Functions: take tuples of Constant and Variable objects as arguments, are applied to each combination of the variables' individuals, and return a tensor in \(\mathbb{R}^{(d_1 \times \dotsc \times d_n)}\) for each. Like predicates, functions are backed by a trainable nnx.Module, such as a neural network, or by a lambda expression.
- Connectives: combine an arbitrary number of expressions into a single expression, e.g. \(\land, \lor, \neg, \Rightarrow, \leftrightarrow\).
- (Masked) quantifiers \(\forall\) and \(\exists\): quantify over a set of Variables, optionally restricted to the individuals that satisfy a given mask, and evaluate how well those individuals fulfill a given expression.
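To make the grounding semantics above concrete, here is a minimal standalone sketch in plain JAX. It does not use the ltnjax API: the arrays c, x, y and the predicate near are hypothetical, chosen only to show how a Variable stacks its m individuals along a leading axis and how a predicate is evaluated on every combination of individuals.

```python
import jax
import jax.numpy as jnp

# Constant: a single tensor with one feature dimension (d_1 = 2).
c = jnp.array([1.0, 1.0])

# Variables: m individuals stacked along the leading axis.
x = jnp.array([[0.0, 1.0], [1.0, 1.0], [3.0, 1.0]])  # shape (3, 2): m = 3
y = jnp.array([[1.0, 1.0], [0.0, 0.0]])              # shape (2, 2): m = 2

def near(u, v):
    """Hypothetical predicate: a truth value in (0, 1] that decays with squared distance."""
    return jnp.exp(-jnp.sum((u - v) ** 2))

# Predicate applied to a Variable and a Constant: one truth value per individual of x.
near_x_c = jax.vmap(near, in_axes=(0, None))(x, c)  # shape (3,)

# Predicate applied to two Variables: one truth value per pair (x_i, y_j).
# The inner vmap maps over y, the outer vmap over x, so axis 0 indexes x and axis 1 indexes y.
near_x_y = jax.vmap(jax.vmap(near, in_axes=(None, 0)), in_axes=(0, None))(x, y)  # shape (3, 2)
```

Keeping one axis per free variable, as in near_x_y, is what later lets a quantifier aggregate over exactly the axis of the variable it binds.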
Repository structure
- docs/ -- Documentation.
- examples/ -- Code examples.
- LICENSES/ -- License files for all components used in this project. See also LICENSE.md.
- ltnjax/core.py -- Core system for defining constants, variables, functions, connectives and quantifiers.
- ltnjax/fuzzy_ops.py -- A collection of fuzzy logic operators defined using JAX primitives.
- scripts/ -- Scripts for testing the code.
- tests/ -- Tests.
- tutorials/ -- Tutorials written as Jupyter notebooks.
Getting Started
Tutorials
tutorials/ contains tutorials for getting started with LTN. We suggest completing them in order. They cover the following topics:
- Grounding in LTN (part 1): Real Logic, constants, predicates, functions, variables;
- Grounding in LTN (part 2): connectives and quantifiers (plus a supplement on choosing appropriate operators for learning);
- Learning in LTN: using satisfiability of LTN formulas as a training objective.
The tutorials are implemented as Jupyter notebooks.
Training
To train neural networks together with Logic Tensor Networks (LTN), you can either use a neural network as a Function or Predicate inside the LTN, or put trainable values into Variables or Constants.
In both cases, define an nnx.Module: create its submodules (nnx.Module) and trainable parameters (nnx.Param) in the __init__ method, and build the LTN expression in the __call__ method.
For the first option, see train_mlp.py; for the second, see trainable_expr.py.
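The idea of using satisfiability as a training objective can be sketched in plain JAX without the ltnjax or nnx APIs. In this sketch a hypothetical logistic predicate P with parameters w and b is fit so that the axioms \(\forall x \in \text{pos}: P(x)\) and \(\forall x \in \text{neg}: \neg P(x)\) are satisfied. The toy data, the p-mean-error aggregator for \(\forall\), the product t-norm for combining the axioms, and plain gradient descent are all choices made for this illustration; in ltnjax the parameters would instead live in an nnx.Module as described above.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: positives around (+1, +1), negatives around (-1, -1).
pos = jnp.array([[1.0, 1.2], [0.8, 1.0], [1.3, 0.9]])
neg = jnp.array([[-1.0, -0.8], [-1.2, -1.1], [-0.9, -1.3]])

def predicate_p(params, x):
    # Logistic predicate: one truth value in (0, 1) per individual in x.
    return jax.nn.sigmoid(x @ params["w"] + params["b"])

def forall(truth, p=2.0):
    # p-mean-error aggregator: 1 - (mean((1 - a)^p))^(1/p).
    return 1.0 - jnp.mean((1.0 - truth) ** p) ** (1.0 / p)

def sat(params):
    ax1 = forall(predicate_p(params, pos))        # forall x in pos: P(x)
    ax2 = forall(1.0 - predicate_p(params, neg))  # forall x in neg: not P(x)
    return ax1 * ax2                              # product t-norm joins the axioms

def loss(params):
    # Maximizing satisfaction is the same as minimizing 1 - sat.
    return 1.0 - sat(params)

params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
grad_fn = jax.jit(jax.value_and_grad(loss))
initial_loss = float(loss(params))

learning_rate = 0.5
for _ in range(200):
    _, grads = grad_fn(params)
    # Plain gradient descent on every leaf of the parameter pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

final_loss = float(loss(params))
```

At initialization every truth value is 0.5, so each axiom is satisfied to degree 0.5 and the loss starts at 0.75; training then pushes P towards 1 on the positives and towards 0 on the negatives.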
Examples
- The above example can be found in connectives.py.
- A basic example to train an MLP can be found in train_mlp.py.
- In the framework, every constant and variable can be trainable. For trainable values, pass nnx.Param objects; for non-trainable values, any ArrayLike works. See, for example, trainable_expr.py.
- During training, watch out for gradient issues; see gradient_issues.py.
- An example for masked quantifiers can be found in masked_quantifiers.py.
- weighted_ops.py illustrates how to use weighted aggregation operations.
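Masked and weighted quantifiers can both be read as weighted aggregation over the individuals of a variable. The sketch below assumes the p-mean (for \(\exists\)) and p-mean-error (for \(\forall\)) aggregators from the LTN literature and treats a mask as 0/1 weights. The function names, the default p = 2, and the convention for an empty selection (\(\forall\) returns 1, \(\exists\) returns 0) are assumptions for illustration, not the ltnjax API or its documented behavior.

```python
import jax.numpy as jnp

def _weighted_mean(values, weights):
    # Weighted mean sum(w * v) / sum(w); guards against an empty selection (sum(w) == 0).
    total = jnp.sum(weights)
    nonempty = total > 0
    safe_total = jnp.where(nonempty, total, 1.0)
    return jnp.sum(weights * values) / safe_total, nonempty

def forall_weighted(truth, weights, p=2.0):
    # Weighted p-mean-error: 1 - (sum w (1 - a)^p / sum w)^(1/p); assumed 1 if nothing is selected.
    mean, nonempty = _weighted_mean((1.0 - truth) ** p, weights)
    return jnp.where(nonempty, 1.0 - mean ** (1.0 / p), 1.0)

def exists_weighted(truth, weights, p=2.0):
    # Weighted p-mean: (sum w a^p / sum w)^(1/p); assumed 0 if nothing is selected.
    mean, nonempty = _weighted_mean(truth ** p, weights)
    return jnp.where(nonempty, mean ** (1.0 / p), 0.0)

def forall_masked(truth, mask, p=2.0):
    # A boolean mask is just a 0/1 weight vector.
    return forall_weighted(truth, mask.astype(truth.dtype), p)

# Usage: truth values of some expression for the 4 individuals of a variable.
truth = jnp.array([0.9, 0.2, 0.8, 0.6])
mask = jnp.array([True, False, True, False])  # only individuals 0 and 2 are quantified

all_sat = forall_weighted(truth, jnp.ones(4))               # aggregates all four individuals
masked_sat = forall_masked(truth, mask)                     # aggregates 0.9 and 0.8 only
empty_sat = forall_masked(truth, jnp.zeros(4, dtype=bool))  # empty selection
some_sat = exists_weighted(truth, jnp.ones(4))
```

Non-binary weights simply let some individuals count more than others in the same formulas.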
Improvements to previous LTN frameworks
- A broader collection of fuzzy operations, see ltnjax/fuzzy_ops.py.
- This framework allows ltnjax.Connective to take arbitrarily many arguments; see examples/connectives.py.
- This framework supports weighted aggregators; see examples/weighted_ops.py.
- The framework supports empty variables, i.e. \(x \in \emptyset\). This matters because variables may be filtered and can end up empty.
- The framework now allows ltnjax.Constants and ltnjax.Variables to be passed directly to ltnjax.Connective or ltnjax.Quantifier. In older versions, an ltnjax.Function (or Predicate) had to be applied first.
- The framework allows ltnjax.Connective and ltnjax.Quantifier to handle tensors with feature dimensions, that is, not only truth values in \([0,1]\) (as needed for logical formulas) but also tensors in \(\mathbb{R}^{(d_1 \times \dotsc \times d_n)}\). This can be useful, for example, for summing up predicate outputs to use them as a regularization term.
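The points on n-ary connectives, feature-dimension tensors and empty variables can be illustrated with plain JAX. The sketch below uses hypothetical function names (not the ltnjax API): it folds the product t-norm and the probabilistic sum over one or more arguments, aggregates a real-valued tensor with a feature dimension over the individual axis of a variable (for instance, to build a regularization term), and shows that a sum over an empty individual axis is still well defined in JAX.

```python
import jax.numpy as jnp

def and_n(*truths):
    # n-ary product t-norm: T(a_1, ..., a_n) = a_1 * ... * a_n (one or more arguments).
    return jnp.prod(jnp.stack(jnp.broadcast_arrays(*truths)), axis=0)

def or_n(*truths):
    # n-ary probabilistic sum: S(a_1, ..., a_n) = 1 - (1 - a_1) * ... * (1 - a_n).
    return 1.0 - jnp.prod(1.0 - jnp.stack(jnp.broadcast_arrays(*truths)), axis=0)

# Three truth-valued arguments; the scalar is broadcast against the others.
a = jnp.array([0.5, 1.0])
b = jnp.array([0.5, 0.5])
c = jnp.array(0.8)
conj = and_n(a, b, c)  # values [0.2, 0.4]
disj = or_n(a, b, c)   # values [0.95, 1.0]

# Hypothetical outputs for a variable with m = 3 individuals, each a tensor in R^(d_1), d_1 = 2.
features = jnp.array([[0.5, -1.0], [2.0, 0.0], [1.5, 1.0]])  # shape (m, d_1)
feature_sum = jnp.sum(features, axis=0)  # aggregate over individuals, keep d_1: [4.0, 0.0]
reg = jnp.sum(feature_sum ** 2)          # scalar regularization term: 16.0

# A variable with no individuals: summing over its (empty) axis yields zeros, not an error.
empty_sum = jnp.sum(jnp.zeros((0, 2)), axis=0)  # shape (2,)
```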