Coverage for src/ltnjax/core.py: 90%
244 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-26 11:35 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-26 11:35 +0000
1# SPDX-FileCopyrightText: 2026 German Aerospace Center (DLR)
2# SPDX-License-Identifier: MIT
3#
4from __future__ import annotations
6from collections.abc import Callable, Sequence
7import types
8from typing import Any
10from flax import nnx
12# Annotations
13from jax import Array # should be used for outputs
14import jax.numpy as jnp
15from jax.typing import ArrayLike # should be used for inputs
16import numpy as np
18import ltnjax as ltn
# Variables (and the free variables of LTN objects) are identified by plain
# string labels.
VarLabel = str
class LTNObject(nnx.Module):
    r"""Class representing a generic LTN object.

    An LTN object contains the results of an expression with variables
    `free_vars`.

    In LTNjax, LTN objects are constants, variables, and outputs of predicates,
    formulas, functions, connectives, and quantifiers.

    Internally this contains a tensor `value` and an ordered list `free_vars`.
    If the expression has $n$ `free_vars` $x_1,\dotsc,x_n$, then
    the first $n$ axes belong to these variables. The remaining axes are
    reserved for the `feature_dims` of the result; there can be
    $[0,\infty)$ many of them. For each setting of the `free_vars`, the
    tensor contains the corresponding result. This becomes clearer by looking
    at the example below.

    The class extends `nnx.Module` since it may contain (trainable) values.

    Attributes:
        value: The result of the expression. If there are $n$
            `free_vars`, then the first $n$ axes belong to these
            `free_vars` in the same order.
        free_vars: The free variables that are contained in the expression.
            `free_vars` is ordered in a way that if we have $n$
            free_vars, the first $n$ axes of `value` belong to these
            variables.

    Examples:
        We define the variables $x \in \{1, 0\}$ and $y \in \{0, 1\}$
        and the expression $x * y$.
        ```python
        >>> import ltnjax as ltn
        >>> x = ltn.Variable('x', [1., 0.])
        >>> y = ltn.Variable('y', [0., 1.])
        >>> prod = ltn.Connective(ltn.fuzzy_ops.AndProd(stable=False))
        >>> res = prod(x, y)
        ```

        If we choose x to be 1. (the value at index 0) and y to be 1. (the
        value at index 1), we get $1 * 1 = 1$. Taking a single index fixes
        the variable, so it is no longer free.
        ```python
        >>> fixed = res.take('x', 0).take('y', 1)
        >>> print(fixed.value, fixed.free_vars)
        1.0 []
        ```

        If we only choose y to be 1. (the value at index 1), we get the
        LTN object $x * 1$. The result still depends on which x we chose.
        Either $x=0$ and we get $x * 1=0$ or $x=1$ and we
        get $x * 1=1$.
        ```python
        >>> partial = res.take('y', 1)
        >>> print(partial.value, partial.free_vars)
        [1. 0.] ['x']
        ```

        Taking a list of indices keeps the variable free; it then only
        ranges over the selected individuals:
        ```python
        >>> subset = res.take('y', [1])
        >>> print(subset.value.shape, subset.free_vars)
        (2, 1) ['x', 'y']
        ```

        If we choose no variable, then the LTN object contains all
        possible results:
        ```python
        >>> print(res.value.shape, res.free_vars)
        (2, 2) ['x', 'y']
        >>> print(res.value)
        [[0. 1.]
         [0. 0.]]
        ```

        Note that LTN objects can also be tensors of any rank:
        ```python
        >>> x = ltn.Variable('x', [[[3.1415926535, 2.7182818284],
        ...                         [1.414213562, .6931471805]],
        ...                        [[1., 0.], [0., 1.]]])
        >>> print(x.value.shape, x.free_vars)
        (2, 2, 2) ['x']
        ```
        Here, the first axis belongs to free variable 'x' while the remaining
        axes are feature dimensions.
    """

    def __init__(
        self,
        value: ArrayLike,
        free_vars: list[VarLabel],
        trainable: bool = False,
    ) -> None:
        """Constructor.

        Args:
            value: The result of the expression. If there are $n$
                `free_vars`, then the first $n$ axes belong to these
                `free_vars` in the same order. It is converted to a
                float32 array.
            free_vars: List of labels of free variables that are contained in
                the expression. `free_vars` is ordered in a way that if we
                have $n$ free_vars, the first $n$ axes of `value`
                belong to these variables.
            trainable: Flag indicating whether the LTN constant is trainable
                (embedding) or not. If True, `value` is wrapped in an
                `nnx.Param`.
        """
        if not trainable:
            self.value = jnp.asarray(value, dtype=jnp.float32)
        else:
            # We put a MyPy ignore here, because nnx.Param behaves as an
            # jnp.Array.
            self.value = nnx.Param(jnp.asarray(value, dtype=jnp.float32))  # type: ignore
        self.free_vars: list[VarLabel] = free_vars

    def __repr__(self) -> str:
        """Representation function.

        Called by the repr() built-in function to compute the "official"
        string representation of an object.
        """
        return (
            f"ltn.{self.__class__.__name__}(value={self.value}, "
            f"free_vars={self.free_vars})"
        )

    def _copy(self) -> LTNObject:
        """Copy function.

        Copy the LTN object but point to the same tensor, for
        gradient tracking. The list of free variables is copied, so that
        the copy and the original do not share (and cannot accidentally
        mutate) the same list.

        Returns:
            Copy of this LTN object instance.
        """
        return LTNObject(self.value, list(self.free_vars))

    def _get_axis_of_free_var(self, free_var: VarLabel) -> int:
        """Axis of `free_var`.

        Given a free variable `free_var`, returns the axis in attribute
        `value` that belongs to this free variable.

        Args:
            free_var: Label of the free variable, whose axis we want.

        Returns:
            Axis that belongs to the free variable `free_var`.

        Raises:
            ValueError: If `free_var` does not occur in `free_vars`.
        """
        if free_var not in self.free_vars:
            raise ValueError(
                f"{free_var} is not a free variable occurring in the LTN "
                "object."
            )
        return self.free_vars.index(free_var)

    def _get_dim_of_free_var(self, free_var: VarLabel) -> int:
        """Dimension of `free_var`.

        Given a label `free_var`, returns the corresponding dimension.

        Args:
            free_var: The label of the variable.

        Returns:
            The dimension that corresponds to variable `free_var`.

        Raises:
            ValueError: If `free_var` does not occur in `free_vars`.

        Note:
            Do not confuse `axis` with `dimension`. The first is the position
            in the shape of `value` where the second is the value.
        """
        return jnp.shape(self.value)[self._get_axis_of_free_var(free_var)]

    def shape(self) -> tuple[int, ...]:
        """Returns the shape of the grounding of the LTN object.

        Returns:
            The shape of the grounding of the LTN object.
        """
        return self.value.shape

    def take(self, free_var: VarLabel, indices: int | list[int]) -> LTNObject:
        r"""Take elements along the axis that corresponds to `free_var`.

        Args:
            free_var: The variable we want to fix.
            indices: This is either an integer or a list of integers from
                $\{0, \dotsc, n-1\}$ where $n$ is the dimension
                of variable `free_var`. This parameter has the same effect as
                `jnp.array(indices)` in `jax.numpy.take`.

        Returns:
            The LTN object containing the elements along the axis that
            corresponds to `free_var`. For a single integer index the axis
            is removed and `free_var` is no longer free. For a list of
            indices the axis is kept (restricted to the selected
            individuals), so `free_var` stays a free variable.

        Raises:
            ValueError: If `indices` is neither an int nor a list of ints,
                or if `free_var` does not occur in `free_vars`.
        """
        indices_arr: Array = jnp.array(indices)
        if jnp.ndim(indices_arr) not in (0, 1):
            raise ValueError("Give a single indice or a list of indices.")

        axis = self._get_axis_of_free_var(free_var)
        result = self._copy()
        result.value = jnp.take(self.value, indices_arr, axis=axis)
        if jnp.ndim(indices_arr) == 0:
            # A scalar index removes the axis of `free_var`, so the variable
            # must be dropped as well to keep axes and `free_vars` aligned.
            result.free_vars = [v for v in self.free_vars if v != free_var]
        # A 1-D index keeps the axis in place, hence `free_vars` (already
        # copied by `_copy`) stays unchanged.
        return result
class Variable(LTNObject):
    r"""Class representing an LTN variable.

    A variable $x$ ranges over a <b>finite number</b> $n$ of individuals,
    each of which is a tensor of any rank, i.e. with or without feature
    dimensions.

    Without feature dimensions:
    $x \in \mathbb{R}^n$.

    With feature dimensions $d_1 \times \dotsc \times d_m$:
    $x \in \mathbb{R}^{n \times d_1 \times \dotsc \times d_m}$.

    Attributes:
        label: The name of the variable.
        value: The array describes a batch of individuals;
            The first axis describes the number of individuals and
            the optional remaining axes are the `feature_dims` of the
            individuals. The `feature_dims` of each individual must be equal.
        free_vars: The free variables that are contained in the expression.
            `free_vars` is ordered in a way that if we have $n$
            free_vars, the first n axes of `value` belong to these variables.
            For a freshly created variable this is `[label]`.

    Raises:
        ValueError: If `var_label` starts with one of the reserved strings
            `diag` or `_flat`.

    Note:
        - The first dimension $n$ of an LTN variable is associated with
          the number of individuals in the variable, while the other $m$
          axes are associated with the features of the individuals;
        - A scalar input is treated as a single individual, i.e. it is
          stored with shape `(1,)`;
        - If the Variable should contain truth values in `[0., 1.]` that are
          updated during training, ensure that they will stay in that
          interval. This could be done by using sigmoid or tanh operations on
          the values. Avoid using `jax.numpy.clip(x, 0., 1.)` during training
          as this can lead to gradient issues.
    """

    def __init__(
        self, var_label: VarLabel, individuals: Any, trainable: bool = False
    ) -> None:
        """Constructor.

        Args:
            var_label: The name of the variable.
            individuals: An object that is convertible to an array. This
                includes JAX arrays, NumPy arrays, Python scalars, Python
                collections like lists and tuples, objects with an
                ``__array__`` method, and objects supporting the Python buffer
                protocol.

                The array describes a batch of individuals;
                The first axis describes the number of individuals and
                the optional remaining axes are the `feature_dims` of the
                individuals. The `feature_dims` of each individual must be
                equal.
            trainable: Flag indicating whether the LTN variable is trainable
                (embedding) or not.

        Raises:
            ValueError: If `var_label` starts with `diag` or `_flat`; these
                prefixes are reserved for internal labels.
        """
        # Check inputs
        for reserved in ["diag", "_flat"]:
            if var_label.startswith(reserved):
                raise ValueError(
                    f"Labels starting with {reserved} are reserved."
                )

        # This is necessary as the input value could be scalars, lists,
        # numpy arrays, etc.
        value = jnp.asarray(individuals, dtype=jnp.float32)

        # Ensure a batch axis for 0D-arrays: a scalar is one individual.
        if jnp.ndim(value) == 0:
            value = value.reshape(1)
        super().__init__(value, free_vars=[var_label], trainable=trainable)
        self.label = var_label

    def __repr__(self) -> str:
        """Representation function.

        Called by the repr() built-in function to compute the "official"
        string representation of an object.
        """
        return (
            f"ltn.{self.__class__.__name__}(var_label={self.label}, "
            f"value={self.value}, free_vars={self.free_vars})"
        )
class Constant(LTNObject):
    """LTN constant: an array of arbitrary rank without free variables.

    Attributes:
        value: The constant as a float32 array of arbitrary rank.
        free_vars: Always the empty list, as a constant does not depend on
            any variable; every axis of `value` is a feature axis.
    """

    def __init__(self, value: Any, trainable: bool = False) -> None:
        """Build a constant from any array-like object.

        Args:
            value: An object that is convertible to an array. This includes JAX
                arrays, NumPy arrays, Python scalars, Python collections like
                lists and tuples, objects with an ``__array__`` method, and
                objects supporting the Python buffer protocol.
            trainable: Whether the constant is a trainable embedding (wrapped
                in an `nnx.Param`) or a fixed array.
        """
        # Normalise scalars, lists, numpy arrays, ... to a float32 array.
        as_array = jnp.asarray(value, dtype=jnp.float32)
        no_free_vars: list[VarLabel] = []
        super().__init__(as_array, free_vars=no_free_vars, trainable=trainable)

    def __repr__(self) -> str:
        """Return the "official" string representation used by repr()."""
        class_name = type(self).__name__
        return (
            f"ltn.{class_name}(value={self.value}, "
            f"free_vars={self.free_vars})"
        )
def _flatten_free_dims(
    exprs: list[LTNObject], in_place: bool = False
) -> list[LTNObject]:
    r"""Flattens `free_dims` of `exprs`.

    Flattens the $[0,\infty)$ many free dimensions in `exprs` into
    one `_flat_` axis. `free_dims` and the `feature_dims` may be empty. If
    `free_dims` is empty, the `_flat_`-axis will be of size $1$. That is,
    for each LTN object in `exprs` if `expr.value` has shape (5,10,2,3) and
    `expr.free_vars=['x','y']`, we will get an output of shape (50,2,3) with
    `expr.free_vars = ['_flat_x_y']`.

    Args:
        exprs: List of [LTNObject][ltnjax.core.LTNObject].
        in_place: (default=False) Boolean that decides whether we perform the
            operation on a new copy or on the same LTN objects.

    Returns:
        List of the flattened LTN objects. The output shape will be the
        `_flat_`-axis plus the $[0,\infty)$ many feature_dims.

    Note:
        - Do not confuse `free_dims` with `feature_dims`. The `free_dims` are
          these that are associated with `free_vars`.
    """
    if not in_place:
        exprs = [expr._copy() for expr in exprs]
    for expr in exprs:
        n_free = len(expr.free_vars)
        free_shape = expr.value.shape[:n_free]
        feature_shape = expr.value.shape[n_free:]
        # The flat size is computed explicitly instead of passing -1, because
        # -1 is ambiguous for arrays with a zero-sized axis, see
        # https://github.com/numpy/numpy/issues/18519 . `int(...)` is needed
        # since `np.prod(())` (no free variables) returns the float 1.0,
        # which is not a valid array dimension.
        flat_size = int(np.prod(free_shape))
        expr.value = jnp.reshape(expr.value, (flat_size, *feature_shape))
        expr.free_vars = ["_flat_" + "_".join(expr.free_vars)]
    return exprs
class Predicate(nnx.Module):
    """Class representing an LTN predicate.

    An LTN predicate is grounded as a mathematical
    function (either pre-defined or learnable) that maps from some n-ary
    domain of individuals to a real number in [0,1] (fuzzy), which can be
    interpreted as a truth value.

    In LTNjax, the inputs of a predicate are automatically broadcast
    before the computation of the predicate, if necessary. Moreover, the
    output is organized in a tensor where each axis is related to one
    free variable of the inputs.

    Attributes:
        model: A `nnx.Module` that evaluates this predicate.

    Note:
        - `model` will be called with one array per input. Each array has a
          batch dimension and optionally feature dimensions. The batch
          dimension is the axis that results from flattening the axes of the
          free variables. See
          [_flatten_free_dims][ltnjax.core._flatten_free_dims].
        - The output of `model` is reshaped to the shape of the free
          variables, so it must contain exactly one value per batch entry.
    """

    def __init__(
        self, model: nnx.Module | None = None, func: Callable | None = None
    ) -> None:
        """Constructor.

        Initializes the LTN predicate in two different ways:
        1. if `model` is not None, it initializes the predicate with the given
        nnx.Module;
        2. if `model` is None, it uses the `func` as a function to define
        the LTN predicate. Note that, in this case, the LTN predicate is not
        learnable. So, plain functions should only be used for simple
        predicates.

        Args:
            model: (default=None) A `nnx.Module` that evaluates this
                predicate.
            func: (default=None) A callable, e.g. a lambda expression, that
                maps the input arrays to the output array.

        Raises:
            ValueError: If either both `model` and `func` are given or
                neither `model` nor `func` is given.
            TypeError: If `model` is given and not an `nnx.Module` or
                if `func` is given and not callable.
        """
        if model is not None and func is not None:
            raise ValueError(
                "Both model and func parameters have been "
                "specified. Expected only one of the two "
                "parameters to be specified."
            )

        if model is None and func is None:
            raise ValueError(
                "Both model and func parameters have not been "
                "specified. Expected one of the two parameters "
                "to be specified."
            )

        if model is not None:
            if not isinstance(model, nnx.Module):
                raise TypeError(
                    "Predicate() : argument 'model' (position 1) "
                    "must be a nnx.Module, not " + str(type(model))
                )
            self.model = model
        else:  # func is not None
            # Any callable works, since LambdaModel only calls it; this also
            # admits jitted jnp functions and functools.partial objects.
            if not callable(func):
                raise TypeError(
                    "Predicate() : argument 'func' (position 2) "
                    "must be a function, not " + str(type(func))
                )
            self.model = LambdaModel(func)

    def __call__(self, *inputs: LTNObject, **kwargs: Any) -> LTNObject:
        """Evaluates the `model` of the given `inputs` and `kwargs`.

        Args:
            inputs: tuple of [LTNObject][ltnjax.core.LTNObject] to apply on
                `model`.
            kwargs: Further arguments to pass to `model`.

        Returns:
            `model(inputs)` as a float32 LTN object whose axes are the
            (broadcast) free variables of `inputs`.

        Raises:
            TypeError: If `inputs` are not of type
                [LTNObject][ltnjax.core.LTNObject].
        """
        # check input
        for x in inputs:
            if not isinstance(x, LTNObject):
                raise TypeError(
                    "The input to a LTN Predicate should be "
                    f"instances of {LTNObject}. Got an instance "
                    f"of {type(x)} instead."
                )

        # forward input
        exprs = _broadcast_exprs(list(inputs))
        # invariant: flat_inputs has shape: (flat) + feature_dims
        flat_inputs = _flatten_free_dims(exprs)
        # invariant: t_outputs has shape: (batch) + model_output
        t_outputs = self.model(*_as_arrays(flat_inputs), **kwargs)

        # recover shape: after broadcasting all inputs share free vars/dims
        if exprs:
            free_vars = exprs[0].free_vars
            free_dims = jnp.shape(exprs[0].value)[: len(free_vars)]
        else:
            free_vars, free_dims = [], ()

        # Case: Predicate -- one truth value per combination of free vars.
        # This line differs in function vs predicate.
        t_outputs = jnp.reshape(t_outputs, free_dims)

        # ensure that values are float
        t_outputs = jnp.astype(t_outputs, jnp.float32)
        return LTNObject(t_outputs, free_vars)
class Function(nnx.Module):
    """Class representing LTN functions.

    A function that maps $n$ tensors of any rank to
    one single tensor of any rank.

    Attributes:
        model: A `nnx.Module` that evaluates this function.

    Note:
        - `model` will be called with one array per input. Each array has a
          batch dimension and optionally feature dimensions. The batch
          dimension is the axis that results from flattening the axes of the
          free variables. See
          [_flatten_free_dims][ltnjax.core._flatten_free_dims].
        - The output of `model` must keep that batch dimension as its first
          axis; it is expanded back into the axes of the free variables while
          the remaining output axes become feature dimensions.
    """

    def __init__(
        self,
        model: nnx.Module | None = None,
        func: Callable | None = None,
    ) -> None:
        """Constructor.

        Initializes the LTN function in two different ways:
        1. if `model` is not None, it initializes the function with the given
        nnx.Module;
        2. if `model` is None, it uses the `func` as a function to define
        the LTN function. Note that, in this case, the LTN function is not
        learnable. So, plain functions should only be used for simple
        functions.

        Args:
            model: (default=None) A `nnx.Module` that evaluates this
                function.
            func: (default=None) A callable, e.g. a lambda expression, that
                maps the input arrays to the output array.

        Raises:
            ValueError: If either both `model` and `func` are given or
                neither `model` nor `func` is given.
            TypeError: If `model` is given and not an `nnx.Module` or
                if `func` is given and not callable.
        """
        if model is not None and func is not None:
            raise ValueError(
                "Both model and func parameters have been "
                "specified. Expected only one of the two "
                "parameters to be specified."
            )

        if model is None and func is None:
            raise ValueError(
                "Both model and func parameters have not been "
                "specified. Expected one of the two parameters "
                "to be specified."
            )

        if model is not None:
            if not isinstance(model, nnx.Module):
                raise TypeError(
                    "Function() : argument 'model' (position 1) "
                    "must be a nnx.Module, not " + str(type(model))
                )
            self.model = model
        else:  # func is not None
            # Any callable works, since LambdaModel only calls it; this also
            # admits jitted jnp functions and functools.partial objects.
            if not callable(func):
                raise TypeError(
                    "Function() : argument 'func' (position 2) "
                    "must be a function, not " + str(type(func))
                )
            self.model = LambdaModel(func)

    def __call__(self, *inputs: LTNObject, **kwargs: Any) -> LTNObject:
        """Evaluates the `model` of the given `inputs` and `kwargs`.

        Args:
            inputs: tuple of [LTNObject][ltnjax.core.LTNObject] to apply on
                `model`.
            kwargs: Further arguments to pass to `model`.

        Returns:
            `model(inputs)` as a float32 LTN object whose leading axes are
            the (broadcast) free variables of `inputs`, followed by the
            feature axes produced by `model`.

        Raises:
            TypeError: If `inputs` are not of type
                [LTNObject][ltnjax.core.LTNObject].
        """
        # check input
        for x in inputs:
            if not isinstance(x, LTNObject):
                raise TypeError(
                    "The input to a LTN Function should be "
                    f"instances of {LTNObject}. Got an instance "
                    f"of {type(x)} instead."
                )

        # forward input
        exprs = _broadcast_exprs(list(inputs))
        # invariant: flat_inputs has shape: (flat) + feature_dims
        flat_inputs = _flatten_free_dims(exprs)
        # invariant: t_outputs has shape: (batch) + model_output
        t_outputs = self.model(*_as_arrays(flat_inputs), **kwargs)

        # recover shape: after broadcasting all inputs share free vars/dims
        if exprs:
            free_vars = exprs[0].free_vars
            free_dims = jnp.shape(exprs[0].value)[: len(free_vars)]
        else:
            free_vars, free_dims = [], ()

        # Case: Function -- replace the batch axis by the free-variable axes
        # and keep the model's output axes as feature dimensions.
        # This line differs in function vs predicate.
        t_outputs = jnp.reshape(
            t_outputs, (*free_dims, *jnp.shape(t_outputs)[1:])
        )

        # ensure that values are float
        t_outputs = jnp.astype(t_outputs, jnp.float32)
        return LTNObject(t_outputs, free_vars)
class LambdaModel(nnx.Module):
    """Simple `nnx.Module` that wraps a plain callable (e.g. a lambda).

    Attributes:
        lambda_layer: The wrapped callable; it is invoked with the positional
            input arrays on every call.
    """

    def __init__(self, lambda_operator: Callable) -> None:
        """Constructor.

        Args:
            lambda_operator: Callable (e.g. a lambda expression) to wrap.
                It is stored as `lambda_layer`.
        """
        super().__init__()
        self.lambda_layer = lambda_operator

    def __call__(self, *inputs: ArrayLike) -> Array:
        """Calls the wrapped callable with `inputs`.

        Args:
            inputs: Positional input arrays.

        Returns:
            `lambda_layer(*inputs)`.
        """
        return self.lambda_layer(*inputs)
def diag(*variables: Variable) -> list[Variable]:
    """Put `variables` into diagonal-quantification mode.

    All given [Variable][ltnjax.core.Variable] objects are modified in place
    so that they share one joint free-variable label
    `"diag_" + "_".join(labels)`. Broadcasting then treats them as a single
    variable, which prepares them for the use of
    [Quantifier][ltnjax.core.Quantifier]. Use `undiag` to revert.

    Args:
        variables: The [Variable][ltnjax.core.Variable] objects to
            diagonalize.

    Returns:
        A list containing the same (now modified) variable objects.

    Raises:
        TypeError: If any element of `variables` is not a
            [Variable][ltnjax.core.Variable].
        ValueError: If a variable is already diagonalized, i.e. its first
            free-variable label starts with `diag_`. No variable is
            modified in that case.
    """
    var_list = list(variables)
    if any(not isinstance(candidate, Variable) for candidate in var_list):
        raise TypeError(
            "Expected parameter 'vars' to be a tuple of Variable, "
            "but got " + str([type(v) for v in var_list])
        )

    # Validate everything before touching any variable.
    already_diagged = next(
        (v for v in var_list if v.free_vars[0].startswith("diag_")), None
    )
    if already_diagged is not None:
        raise ValueError(
            "Trying to diag a variable that is already temporarily "
            f"diagged: {already_diagged.label}."
        )

    joint_label = "diag_" + "_".join(v.label for v in var_list)
    for v in var_list:
        v.free_vars = [joint_label]
    return var_list
def undiag(*variables: Variable) -> list[Variable]:
    """Leave diagonal-quantification mode for the given variables.

    Each [Variable][ltnjax.core.Variable] is modified in place so that its
    only free variable is again its own label, which restores the ordinary
    `LTN broadcasting` behaviour.

    Args:
        variables: The [Variable][ltnjax.core.Variable] objects for which
            the `diagonal quantification` setting has to be removed.

    Returns:
        A list containing the same (now modified) variable objects.

    Raises:
        TypeError: If any element of `variables` is not a
            [Variable][ltnjax.core.Variable].
    """
    var_list = [*variables]
    if any(not isinstance(candidate, Variable) for candidate in var_list):
        raise TypeError(
            "Expected parameter 'vars' to be a tuple of Variable, "
            "but got " + str([type(v) for v in var_list])
        )

    for v in var_list:
        v.free_vars = [v.label]
    return var_list
def _as_arrays(exprs: Sequence[LTNObject]) -> list[Array]:
    """Collect the `value` of every LTN object.

    Args:
        exprs: Sequence of [LTNObject][ltnjax.core.LTNObject] objects; it is
            only read, never modified.

    Returns:
        A new list with the `value` attribute of each object, in the same
        order as `exprs`.
    """
    arrays: list[Array] = []
    for ltn_obj in exprs:
        arrays.append(ltn_obj.value)
    return arrays
def _broadcast_exprs(
    exprs: list[LTNObject], in_place: bool = False
) -> list[LTNObject]:
    """Broadcasts variables of `exprs`.

    This collects the union of free_vars in `exprs` and broadcasts it to
    each LTN object in `exprs`, i.e. at the end each free variable in `exprs`
    is contained in each LTN object in `exprs` together with its own axis.
    The free variables are ordered by their first appearance in `exprs`.

    Args:
        exprs: List of [LTNObject][ltnjax.core.LTNObject] objects whose
            variables will be broadcast.
        in_place: (default=False) Boolean that decides whether we perform the
            operation on a new copy or on the same LTN objects.

    Returns:
        List of LTN objects but with broadcast free variables. Each returned
        object gets its own `free_vars` list.

    Example:
        `exprs[0]` has shape `(5, 2)` and `free_vars = ['x']` and
        `exprs[1]` has shape `(10, 2)` and `free_vars = ['y']`.
        After broadcasting, both LTN objects will have shape `(5, 10, 2)` and
        `free_vars = ['x', 'y']`.
        >>> import ltnjax as ltn
        >>> import numpy as np
        >>> x = ltn.core.LTNObject(np.ones((5, 2)), ['x'])
        >>> y = ltn.core.LTNObject(np.ones((10, 2)), ['y'])
        >>> res = ltn.core._broadcast_exprs([x, y])
        >>> print(res[0].value.shape, res[0].free_vars)
        (5, 10, 2) ['x', 'y']
        >>> print(res[1].value.shape, res[1].free_vars)
        (5, 10, 2) ['x', 'y']
    """
    # Measure the size of each free variable. NOTE(review): if two
    # expressions disagree on a variable's size, the last one wins here.
    free_var_to_dim: dict[str, int] = {}
    for expr in exprs:
        for free_var in expr.free_vars:
            free_var_to_dim[free_var] = expr._get_dim_of_free_var(free_var)
    free_vars = list(free_var_to_dim)

    # broadcast
    if not in_place:
        exprs = [expr._copy() for expr in exprs]
    for expr in exprs:
        free_vars_in_arg = list(expr.free_vars)
        # Ordered (deterministic) list of the variables this expr lacks.
        missing = [v for v in free_vars if v not in free_vars_in_arg]
        for new_free_var in missing:
            # Append a new free axis right after the existing free axes,
            # i.e. before the feature axes, and tile it to the right size.
            new_idx = len(free_vars_in_arg)
            expr.value = jnp.expand_dims(expr.value, axis=new_idx)
            expr.value = jnp.repeat(
                expr.value, free_var_to_dim[new_free_var], axis=new_idx
            )
            free_vars_in_arg.append(new_free_var)
        # Reorder the free axes into the common order; feature axes stay.
        perm = [free_vars_in_arg.index(v) for v in free_vars] + list(
            range(len(free_vars_in_arg), jnp.ndim(expr.value))
        )
        expr.value = jnp.transpose(expr.value, axes=perm)
        expr.free_vars = list(free_vars)
    return exprs
class Connective:
    """Class representing an LTN connective.

    Wrapper for connectives that combines given
    [LTNObject][ltnjax.core.LTNObject] objects according to a given
    operation `connective_op` and also broadcasts variables, see
    [_broadcast_exprs][ltnjax.core._broadcast_exprs].

    Attributes:
        connective_op: The connective operation. Supported are
            unary and binary connective operators and aggregation operators
            from `ltnjax.fuzzy_ops`.
    """

    def __init__(
        self, connective_op: ltn.fuzzy_ops.ConnectiveOperator
    ) -> None:
        """Constructor.

        Args:
            connective_op: The connective operation.
        """
        self.connective_op = connective_op

    def __call__(self, *wffs: LTNObject, **kwargs: Any) -> LTNObject:
        """Applies the connective using the given `connective_op`.

        Args:
            wffs: Tuple of LTN objects.
            kwargs: Further arguments to pass to `connective_op`.

        Returns:
            The resulting [LTNObject][ltnjax.core.LTNObject] object that
            combines the given `wffs` into one joint LTN object.

        Raises:
            TypeError: If `wffs` are not of type
                [LTNObject][ltnjax.core.LTNObject], or if `connective_op` is
                neither a unary nor a binary connective operator nor an
                aggregation operator.
            ValueError: If number of `wffs` does not fit to `connective_op`.
        """
        for x in wffs:
            if not isinstance(x, LTNObject):
                raise TypeError(
                    "The operands of a LTN connective should be "
                    f"instances of {LTNObject}. Got an instance "
                    f"of {type(x)} instead."
                )

        exprs = _broadcast_exprs(list(wffs))
        op = self.connective_op
        if isinstance(op, ltn.fuzzy_ops.UnaryConnectiveOperator):
            if len(exprs) != 1:
                raise ValueError(
                    "wffs must have length 1 since connective_op "
                    "is an UnaryConnectiveOperator."
                )
            t_result = op(*_as_arrays(exprs), **kwargs)
        elif isinstance(op, ltn.fuzzy_ops.BinaryConnectiveOperator):
            if len(exprs) != 2:
                raise ValueError(
                    "wffs must have length 2 since connective_op "
                    "is an BinaryConnectiveOperator."
                )
            t_result = op(*_as_arrays(exprs), **kwargs)
        elif isinstance(op, ltn.fuzzy_ops.AggregationOperator):
            if len(exprs) < 1:
                raise ValueError("wffs must have length at least 1.")
            # Stack the broadcast operands along a new leading axis and
            # aggregate over it.
            t_result = op(jnp.stack(_as_arrays(exprs)), axis=0, **kwargs)
        else:
            # Previously this fell through and failed with an
            # UnboundLocalError on `t_result`.
            raise TypeError(
                "connective_op must be a UnaryConnectiveOperator, a "
                "BinaryConnectiveOperator or an AggregationOperator. Got an "
                f"instance of {type(op)} instead."
            )
        # All operands share the same free variables after broadcasting.
        return LTNObject(t_result, exprs[0].free_vars)
class Quantifier:
    r"""Class representing an LTN quantifier.

    Wrapper for quantifiers. This evaluates a given LTN object `wff` for
    all variable combinations for which the condition `mask` is true. Then,
    the results are aggregated with the aggregation operator `aggreg_op`.

    Attributes:
        aggreg_op: Aggregation operator.
        quantifier: (str = "f" | "e") Decides whether this is a
            "forall"- or "exists"-quantifier. This has no effect on the
            aggregator but is important for cases where we aggregate over
            $\emptyset$. This may happen if the variables are empty
            or the mask masks each variable-combination. In these cases,
            "forall"-expressions are true while "exists"-expressions are
            false.
            If the Quantifier is used with non-truth values, the quantifier
            can be used as one of the <b>neural elements</b>, e.g. `sum` or
            `prod`.

    Raises:
        TypeError: If `aggreg_op` is not of type
            [ConnectiveOperator][ltnjax.fuzzy_ops.ConnectiveOperator].
        ValueError: If `quantifier` is not one of the strings `"f"`
            ("forall") or `"e"` ("exists").

    Note:
        It is possible that the variable-combinations are empty or that the
        condition `mask` will mask every variable-combination. In both cases,
        a "forall"-statement will always be true while an "exists"-statement
        will always be false.
    """
971 def __init__(
972 self, aggreg_op: ltn.fuzzy_ops.ConnectiveOperator, quantifier: str
973 ) -> None:
974 r"""Constructor.
976 Args:
977 aggreg_op: Aggregation operator.
978 quantifier: (str = "f" | "e"`) Decides whether this is a
979 "forall"- or "exists"-quantifier. This has no effect on the
980 aggregator but is important for cases, where we aggregate over
981 $\emptyset$. This may happen if the variables are empty
982 or the mask masks each variable-combination. In these cases,
983 "forall" expressions are true while "exists"-quantifiers are
984 false.
985 If the Quantifier is used with non-truth values, the quantifier
986 can be used as the <b>neural elements</b> like for `sum` or
987 `prod`.
988 """
989 self.aggreg_op = aggreg_op
990 if not isinstance(aggreg_op, ltn.fuzzy_ops.ConnectiveOperator):
991 raise TypeError(
992 "The aggregation operator for the quantifier "
993 "should be an instance of "
994 f"{ltn.fuzzy_ops.ConnectiveOperator}. Got an "
995 f"instance of {type(aggreg_op)} instead."
996 )
997 if quantifier not in ["f", "e"]:
998 raise ValueError(
999 '`quantifier` for the quantifier should be "f" or "e".'
1000 )
1001 self.quantifier = quantifier
1003 def __call__(
1004 self,
1005 variables: Variable | list[Variable],
1006 wff: LTNObject,
1007 mask: LTNObject | None = None,
1008 **kwargs: Any,
1009 ) -> LTNObject:
1010 """Applies the quantification and outputs the resulting LTN object.
1012 As a side-effect, this removes the `diagonal quantification` from the
1013 given `variables`.
1014 Refer to [undiag][ltnjax.core.undiag].
1016 Args:
1017 variables: Variable or list of variables.
1018 wff: LTN object.
1019 mask: (default=None) Condition operation.
1020 kwargs: Further arguments to pass to `connective_op`.
1022 Returns:
1023 The resulting LTN object.
1025 Raises:
1026 TypeError: If the values of `variables` are not instances of
1027 [Variable][ltnjax.core.Variable] or if `wff` is not of type
1028 [LTNObject][ltnjax.core.LTNObject].
1029 """
1030 # check inputs
1031 variables = (
1032 [variables] if not isinstance(variables, list) else variables
1033 )
1034 for x in variables:
1035 if not isinstance(x, Variable):
1036 raise TypeError(
1037 "The quantified variables should be "
1038 f"instances of {Variable}. Got an instance of "
1039 "{type(x)} instead."
1040 )
1041 if not isinstance(wff, LTNObject): 1041 ↛ 1042line 1041 didn't jump to line 1042 because the condition on line 1041 was never true
1042 raise TypeError(
1043 "The quantified LTN object should be an instance "
1044 f"of {LTNObject}. Got an instance of {type(x)} "
1045 "instead ."
1046 )
1047 if mask is not None and not isinstance(mask, LTNObject):
1048 raise TypeError(
1049 "The mask argument should be an instance of "
1050 f"{LTNObject}. Got an instance of {type(mask)} "
1051 "instead."
1052 )
1054 # Note: For the edge-case that variables are empty.
1055 # Since the empty variables are already passed into predicates that
1056 # return size-0 arrays, we have to check wff.
1057 # This will be done before self.aggreg_op is applied.
1059 aggreg_vars = {var.free_vars[0] for var in variables}
1061 if mask is not None:
1062 # This block is for broadcasting variables in wff and the
1063 # variables in mask.
1064 # Important to put aggreg dims last, to keep other dims in the
1065 # ragged result.
1066 mask = Quantifier._transpose_free_vars(
1067 mask,
1068 new_var_order=[
1069 var for var in mask.free_vars if var not in aggreg_vars
1070 ]
1071 + [var for var in mask.free_vars if var in aggreg_vars],
1072 )
1073 wff = Quantifier._broadcast_wff_and_mask(wff, mask)
1075 mask.value = jnp.astype(mask.value, jnp.bool)
1077 # Ignore vars in variables that do not occur in wff.free_vars.
1078 aggreg_axes = [
1079 wff.free_vars.index(var)
1080 for var in aggreg_vars
1081 if var in wff.free_vars
1082 ]
1083 if wff.value.size == 0: # Check edge case, when wff is empty:
1084 t_result = (
1085 jnp.prod(wff.value, axis=aggreg_axes, **kwargs)
1086 if self.quantifier == "f"
1087 else jnp.sum(wff.value, axis=aggreg_axes, **kwargs)
1088 )
1089 else:
1090 t_result = self.aggreg_op(
1091 wff.value, axis=aggreg_axes, mask=mask.value, **kwargs
1092 )
1094 aggreg_axes_in_mask = [
1095 mask.free_vars.index(var)
1096 for var in aggreg_vars
1097 if var in mask.free_vars
1098 ]
1099 # empty_vars are the variable-combinations that are completely
1100 # masked and we need to apply replacement value `rep_value`.
1101 non_empty_vars = (
1102 jnp.sum(
1103 jnp.astype(mask.value, jnp.int32), axis=aggreg_axes_in_mask
1104 )
1105 != 0
1106 )
1107 rep_value = 1.0 if self.quantifier == "f" else 0
1109 t_result = jnp.where(non_empty_vars, t_result, rep_value)
1110 else:
1111 # ignore vars in variables that do not occur in wff.free_vars
1112 aggreg_axes = [
1113 wff.free_vars.index(var)
1114 for var in aggreg_vars
1115 if var in wff.free_vars
1116 ]
1117 if wff.value.size == 0: # Check edge case, when wff is empty:
1118 t_result = (
1119 jnp.prod(wff.value, axis=aggreg_axes, **kwargs)
1120 if self.quantifier == "f"
1121 else jnp.sum(wff.value, axis=aggreg_axes, **kwargs)
1122 )
1123 else:
1124 t_result = self.aggreg_op(
1125 wff.value, axis=aggreg_axes, **kwargs
1126 )
1127 free_vars_remaining = [
1128 var for var in wff.free_vars if var not in aggreg_vars
1129 ]
1130 result = LTNObject(t_result, free_vars_remaining)
1131 undiag(*variables)
1132 return result
1134 @staticmethod
1135 def _broadcast_wff_and_mask(
1136 wff: LTNObject, mask: LTNObject, in_place: bool = False
1137 ) -> LTNObject:
1138 """Broadcasts the free variables from `mask` to `wff`.
1140 The variables of `mask` are put in the first axes.
1142 Args:
1143 wff: LTN object.
1144 mask: LTN object.
1145 in_place: (default=False) Boolean that decides whether we perform
1146 the operation on `wff` or a new copy.
1148 Returns:
1149 The LTN object `wff` with the vars from `mask` added.
1150 """
1151 if not in_place: 1151 ↛ 1155line 1151 didn't jump to line 1155 because the condition on line 1151 was always true
1152 wff = wff._copy()
1153 # 1. Broadcast wff with vars that are in the mask but not yet in the
1154 # LTN object.
1155 mask_vars_not_in_wff = [
1156 var for var in mask.free_vars if var not in wff.free_vars
1157 ]
1158 for var in mask_vars_not_in_wff:
1159 new_idx = len(wff.free_vars)
1160 wff.value = jnp.expand_dims(wff.value, axis=new_idx)
1161 wff.value = jnp.repeat(
1162 wff.value, mask._get_dim_of_free_var(var), axis=new_idx
1163 )
1164 wff.free_vars.append(var)
1165 # 2. Transpose wff so that the masked vars on the first axes.
1166 vars_not_in_mask = [
1167 var for var in wff.free_vars if var not in mask.free_vars
1168 ]
1169 wff = Quantifier._transpose_free_vars(
1170 wff, new_var_order=mask.free_vars + vars_not_in_mask
1171 )
1172 return wff
1174 @staticmethod
1175 def _transpose_free_vars(
1176 expr: LTNObject, new_var_order: list[VarLabel], in_place: bool = False
1177 ) -> LTNObject:
1178 """Transposes free variables.
1180 This changes the order of variables in `expr.free_vars` and the
1181 axes of `expr.value` will be transposed accordingly.
1183 Args:
1184 expr: The LTN object whose variables will be transposed.
1185 new_var_order: List of variables that defines the new order.
1186 in_place: (default=False) Boolean that decides whether we perform
1187 the operation on a new copy or on the same LTN objects.
1189 Returns:
1190 The transposed LTN object.
1191 """
1192 permutation = [expr.free_vars.index(var) for var in new_var_order]
1193 if not in_place: 1193 ↛ 1195line 1193 didn't jump to line 1195 because the condition on line 1193 was always true
1194 expr = expr._copy()
1195 expr.value = jnp.transpose(expr.value, permutation)
1196 expr.free_vars = new_var_order
1197 return expr