Coverage for src/ltnjax/fuzzy_ops.py: 81%
341 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-26 11:35 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-26 11:35 +0000
1# SPDX-FileCopyrightText: 2026 German Aerospace Center (DLR)
2# SPDX-License-Identifier: MIT
3#
4from abc import ABC, abstractmethod
5from collections.abc import Sequence
6from typing import Any
7import warnings
8from warnings import warn
10# Annotations
11from jax import Array # should be used for outputs
12import jax.numpy as jnp
13from jax.typing import ArrayLike # should be used for inputs
# NOTE(review): this string follows the import statements, so Python does not
# bind it as the module's ``__doc__``; consider moving it above the imports.
"""
The [fuzzy_ops][ltnjax.fuzzy_ops] module contains the Jax implementation of
some common fuzzy logic operators and aggregators.
Refer to the [LTN paper](https://arxiv.org/abs/2012.13635) for a detailed
description of these operators (see the Appendix).

All the operators included in this module support the traditional
NumPy/Jax broadcasting.

The operators have been designed to be used with
[Connective][ltnjax.core.Connective] or [Quantifier][ltnjax.core.Quantifier].
"""
30Axis = int | Sequence[int] | None
31eps = 1e-4
34def not_zeros(x: ArrayLike) -> Array:
35 """Smoothly transforms an array to avoid zero-values.
37 Function that has to be used when we need to assure that the truth value
38 in input to a fuzzy operator is never equal to zero, in such a way to
39 avoid gradient problems. It maps the interval $[0, 1]$ in the
40 interval $]0, 1]$, where the $0$ is excluded.
42 Args:
43 x: Array of truth-values.
45 Returns:
46 The input truth values changed in such a way to prevent gradient
47 problems (0 is changed with a small number near 0).
48 """
49 return jnp.multiply((1 - eps), x + eps)
def not_ones(x: ArrayLike) -> Array:
    r"""Smoothly shrinks truth values so that none of them equals one.

    Use this in front of a fuzzy operator whose gradient misbehaves when a
    truth value is exactly $1$. Every value is multiplied by
    $(1 - \epsilon)$, so the interval $[0, 1]$ is mapped onto
    $[0, 1 - \epsilon]$ and $1$ itself is never produced.

    Args:
        x: Array of truth-values.

    Returns:
        The scaled truth values ($1$ becomes $1 - \epsilon$, $0$ stays $0$).
    """
    scale = 1 - eps
    return jnp.multiply(scale, x)
def sigmoid(x: ArrayLike) -> Array:
    r"""Computes the logistic sigmoid $1 / (1 + e^{-x})$.

    The value is evaluated as $\exp(-\log(1 + e^{-x}))$ via
    `jnp.logaddexp`, which never overflows. The naive formula overflows
    $e^{-x}$ to `inf` for very negative inputs, which turns the gradient
    into NaN ($0 \cdot \infty$).

    Args:
        x: Array.

    Returns:
        `sigmoid(x)`.
    """
    # log(1 + exp(-x)) == logaddexp(0, -x), computed without overflow.
    softplus_neg = jnp.logaddexp(0.0, jnp.negative(x))
    return jnp.exp(jnp.negative(softplus_neg))
def tanh(x: ArrayLike) -> Array:
    r"""Computes the hyperbolic tangent rescaled to the unit interval.

    `jnp.tanh` takes values in $]-1, 1[$; this function returns
    $(\tanh(x) + 1) / 2$, which takes values in $]0, 1[$ (up to floating
    point rounding). The result can therefore be used as a truth value.
    Note that $(\tanh(x) + 1) / 2 = \mathrm{sigmoid}(2x)$.

    Args:
        x: Array.

    Returns:
        `(tanh(x) + 1) / 2`.
    """
    return (jnp.tanh(x) + 1) / 2
class ConnectiveOperator(ABC):
    """Common abstract base of every fuzzy operator in this module.

    Concrete subclasses must override
    [__call__][ltnjax.fuzzy_ops.ConnectiveOperator.__call__].

    Raises:
        NotImplementedError: If the abstract
            [__call__][ltnjax.fuzzy_ops.ConnectiveOperator.__call__] is
            invoked directly.
    """

    @abstractmethod
    def __call__(self, *args: Any, **kwargs: Any):
        """Evaluates the operator; subclasses supply the semantics.

        Args:
            args: Positional operands.
            kwargs: Keyword options.

        Raises:
            NotImplementedError: Unconditionally, because this is abstract.
        """
        raise NotImplementedError
class UnaryConnectiveOperator(ConnectiveOperator):
    """Abstract base for connective operators that take a single operand.

    Concrete subclasses must override
    [__call__][ltnjax.fuzzy_ops.UnaryConnectiveOperator.__call__].

    Raises:
        NotImplementedError: If the abstract
            [__call__][ltnjax.fuzzy_ops.UnaryConnectiveOperator.__call__] is
            invoked directly.
    """

    @abstractmethod
    def __call__(self, *args: Any, **kwargs: Any):
        """Evaluates the unary operator; subclasses supply the semantics.

        Args:
            args: Positional operands.
            kwargs: Keyword options.

        Raises:
            NotImplementedError: Unconditionally, because this is abstract.
        """
        raise NotImplementedError
class BinaryConnectiveOperator(ConnectiveOperator):
    """Abstract base for connective operators that take two operands.

    Concrete subclasses must override
    [__call__][ltnjax.fuzzy_ops.BinaryConnectiveOperator.__call__].

    Raises:
        NotImplementedError: If the abstract
            [__call__][ltnjax.fuzzy_ops.BinaryConnectiveOperator.__call__] is
            invoked directly.
    """

    @abstractmethod
    def __call__(self, *args: Any, **kwargs: Any):
        """Evaluates the binary operator; subclasses supply the semantics.

        Args:
            args: Positional operands.
            kwargs: Keyword options.

        Raises:
            NotImplementedError: Unconditionally, because this is abstract.
        """
        raise NotImplementedError
class AggregationOperator(ConnectiveOperator):
    """Abstract base for operators that aggregate truth values.

    Concrete subclasses must override
    [__call__][ltnjax.fuzzy_ops.AggregationOperator.__call__].

    Raises:
        NotImplementedError: If the abstract
            [__call__][ltnjax.fuzzy_ops.AggregationOperator.__call__] is
            invoked directly.
    """

    @abstractmethod
    def __call__(self, *args: Any, **kwargs: Any):
        """Evaluates the aggregation; subclasses supply the semantics.

        Args:
            args: Positional operands.
            kwargs: Keyword options.

        Raises:
            NotImplementedError: Unconditionally, because this is abstract.
        """
        raise NotImplementedError
class NotStandard(UnaryConnectiveOperator):
    r"""Standard fuzzy negation operator.

    $\neg x = 1 - x$
    """

    def __call__(self, x: ArrayLike) -> Array:
        """Negates the operand with the standard fuzzy negation.

        Args:
            x: Truth values to negate.

        Returns:
            The complement `1 - x` of the given operand.
        """
        complement = jnp.subtract(1.0, x)
        return complement
class NotGodel(UnaryConnectiveOperator):
    """Godel fuzzy negation operator.

    Returns $1$ where $x = 0$ and $0$ everywhere else.

    Notes:
        - This is not recommended for machine learning: the output is
          piecewise constant, so the gradient is always $0$.
    """

    def __call__(self, x: ArrayLike) -> Array:
        """Negates the operand with the Godel fuzzy negation.

        Args:
            x: Truth values to negate.

        Returns:
            A `float32` array that is `1.0` where `x == 0` and `0.0`
            elsewhere.
        """
        is_zero = jnp.equal(x, 0)
        return is_zero.astype(jnp.float32)
class AndMin(BinaryConnectiveOperator):
    r"""Godel fuzzy conjunction operator (minimum t-norm).

    $\land_{Godel}(x, y) := \min(x, y)$

    Notes:
        - $\min$ has a <b>single-passing gradient</b>: when $x \neq y$ the
          smaller operand receives gradient $1$ and the other $0$. When
          $x = y$, each operand receives $0.5$.
    """

    def __call__(self, x: ArrayLike, y: ArrayLike) -> Array:
        """Conjoins two operands with the Godel (minimum) t-norm.

        Args:
            x: Left operand (truth values).
            y: Right operand (truth values).

        Returns:
            The element-wise minimum of `x` and `y`.
        """
        smaller = jnp.minimum(x, y)
        return smaller
class AndProd(BinaryConnectiveOperator):
    r"""Goguen fuzzy conjunction operator (product t-norm).

    $\land_{Goguen}(x, y) = xy$

    Attributes:
        stable: (default=True) Whether to evaluate the
            [stable](../../stable.md) variant, which first passes both
            operands through `not_zeros`.

    Notes:
        - The product t-norm has <b>vanishing gradients</b> for $x=y=0$.
    """

    def __init__(self, stable: bool = True):
        """Stores the default stabilisation setting.

        Args:
            stable: (default=True) Whether to use the
                [stable](../../stable.md) variant unless a call overrides
                it.
        """
        self.stable = stable

    def __call__(
        self, x: ArrayLike, y: ArrayLike, stable: bool | None = None
    ) -> Array:
        """Conjoins two operands with the product t-norm.

        Args:
            x: Left operand (truth values).
            y: Right operand (truth values).
            stable: (default=None) Per-call override of the `stable`
                attribute; `None` keeps the instance setting.

        Returns:
            The Goguen fuzzy conjunction of `x` and `y`.
        """
        use_stable = self.stable if stable is None else stable
        if not use_stable:
            return jnp.multiply(x, y)
        return jnp.multiply(not_zeros(x), not_zeros(y))
class AndLuk(BinaryConnectiveOperator):
    r"""Lukasiewicz fuzzy conjunction operator.

    $\land_{Lukasiewicz}(x, y) = \max(x + y - 1, 0)$

    Notes:
        - AndLuk has <b>vanishing gradients</b> if $x + y < 1$.
          If $x + y = 1$, both gradients are $0.5$.
    """

    def __call__(self, x: ArrayLike, y: ArrayLike) -> Array:
        """Conjoins two operands with the Lukasiewicz t-norm.

        Args:
            x: Left operand (truth values).
            y: Right operand (truth values).

        Returns:
            `max(x + y - 1, 0)`, element-wise.
        """
        shifted = x + y - 1.0
        return jnp.maximum(shifted, 0.0)
class OrMax(BinaryConnectiveOperator):
    r"""Godel fuzzy disjunction operator (maximum t-conorm).

    $\lor_{Godel}(x, y) = \max(x, y)$

    Notes:
        - $\max$ has a <b>single-passing gradient</b>: when $x \neq y$ the
          larger operand receives gradient $1$ and the other $0$. When
          $x = y$, each operand receives $0.5$.
    """

    def __call__(self, x: ArrayLike, y: ArrayLike) -> Array:
        """Disjoins two operands with the Godel (maximum) t-conorm.

        Args:
            x: Left operand (truth values).
            y: Right operand (truth values).

        Returns:
            The element-wise maximum of `x` and `y`.
        """
        larger = jnp.maximum(x, y)
        return larger
class OrProbSum(BinaryConnectiveOperator):
    r"""Goguen fuzzy disjunction operator (probabilistic sum).

    $\lor_{Goguen}(x, y) = x + y - xy$

    Attributes:
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.

    Notes:
        - The product t-conorm has <b>vanishing gradients</b> for $x=y=1$.
        - By De Morgan's law $u \lor v \Leftrightarrow \neg(\neg u \land \neg v)$,
          this equals $1 - (1 - x)(1 - y)$, i.e. the dual of
          [AndProd][ltnjax.fuzzy_ops.AndProd] under
          [NotStandard][ltnjax.fuzzy_ops.NotStandard]. The code evaluates
          $x + y - xy$ directly.
        - The stable version passes both operands through `not_ones`
          before evaluating the formula.
    """

    def __init__(self, stable: bool = True):
        """Constructor.

        Args:
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        self.stable = stable

    def __call__(
        self, x: ArrayLike, y: ArrayLike, stable: bool | None = None
    ) -> Array:
        """It applies the Goguen fuzzy disjunction operator to the given
        operands.

        Args:
            x: First operand on which the operator has to be applied.
            y: Second operand on which the operator has to be applied.
            stable: (default=None) Per-call override of the `stable`
                attribute; `None` keeps the instance setting.

        Returns:
            The Goguen fuzzy disjunction of the two operands.
        """
        stable = self.stable if stable is None else stable
        if stable:
            # Keep both operands below 1, where the gradient would vanish.
            x, y = not_ones(x), not_ones(y)
        return jnp.subtract(jnp.add(x, y), jnp.multiply(x, y))
class OrLuk(BinaryConnectiveOperator):
    r"""Lukasiewicz fuzzy disjunction operator.

    $\lor_{Lukasiewicz}(x, y) = \min(x + y, 1)$

    Notes:
        - OrLuk has <b>vanishing gradients</b> for $x + y > 1$.
          If $x + y = 1$, both gradients are $0.5$.
    """

    def __call__(self, x: ArrayLike, y: ArrayLike) -> Array:
        """Disjoins two operands with the Lukasiewicz t-conorm.

        Args:
            x: Left operand (truth values).
            y: Right operand (truth values).

        Returns:
            `min(x + y, 1)`, element-wise.
        """
        total = x + y
        return jnp.minimum(total, 1.0)
class OrSmoothMaximumUnit(BinaryConnectiveOperator):
    r"""[Smooth maximum unit](https://en.wikipedia.org/wiki/Smooth_maximum)
    fuzzy disjunction operator, that approximates the maximum.

    $\max_\epsilon(a,b) = (a+b+|a-b|_\epsilon) / 2$, where we
    approximate $|a-b|$ by $|a-b|_\epsilon = \sqrt{(a-b)^2 + \epsilon}$.

    Attributes:
        epsilon: A parameter for $|a-b|_\epsilon$ that
            approximates $|a-b|$.
    """

    def __init__(self, epsilon: float = 1e-4):
        r"""This constructor has to be used to set the epsilon parameter.

        Args:
            epsilon: A parameter for $|a-b|_\epsilon$ that
                approximates $|a-b|$.
        """
        self.epsilon = epsilon

    def __call__(
        self, x: ArrayLike, y: ArrayLike, epsilon: float | None = None
    ) -> Array:
        r"""It applies the smooth maximum unit fuzzy disjunction to the given
        operands.

        Args:
            x: First operand on which the operator has to be applied.
            y: Second operand on which the operator has to be applied.
            epsilon: (default=None) Parameter for $|a-b|_\epsilon$ that
                approximates $|a-b|$. If `None`, the instance's `epsilon`
                attribute is used.

        Returns:
            The smooth maximum unit fuzzy disjunction of the two operands.
        """
        # Honour the per-call override; previously it was computed but
        # `self.epsilon` was used regardless.
        epsilon = self.epsilon if epsilon is None else epsilon
        # sqrt((x - y)^2 + eps) is a differentiable surrogate for |x - y|.
        smooth_abs = jnp.sqrt(jnp.square(jnp.subtract(x, y)) + epsilon)
        return jnp.divide(x + y + smooth_abs, 2)
class ImpliesKleeneDienes(BinaryConnectiveOperator):
    r"""Kleene-Dienes fuzzy implication operator.

    $\rightarrow_{KleeneDienes}(x, y) = \max(1 - x, y)$

    Notes:
        - ImpliesKleeneDienes has a <b>single-passing gradient</b>: it
          flows to $x$ if $1 - x > y$ and to $y$ if $1 - x < y$.
          If $1 - x = y$, both gradients are $0.5$.
    """

    def __call__(self, x: ArrayLike, y: ArrayLike) -> Array:
        """Evaluates the Kleene-Dienes implication `x -> y`.

        Args:
            x: Antecedent truth values.
            y: Consequent truth values.

        Returns:
            `max(1 - x, y)`, element-wise.
        """
        negated_x = 1.0 - x
        return jnp.maximum(negated_x, y)
class ImpliesGodel(BinaryConnectiveOperator):
    r"""Godel fuzzy implication operator.

    $\rightarrow_{Godel}(x, y) = \left\{\begin{array}{ c l }1 & \quad \textrm{if } x \le y \\ y & \quad \textrm{otherwise} \end{array} \right.$

    Notes:
        - ImpliesGodel has <b>vanishing gradients</b> if $x \le y$.
          Otherwise, it <b>passes a single gradient</b> to $y$.
    """  # noqa: E501

    def __call__(self, x: ArrayLike, y: ArrayLike) -> Array:
        """Evaluates the Godel implication `x -> y`.

        Args:
            x: Antecedent truth values.
            y: Consequent truth values.

        Returns:
            `1` where `x <= y`, otherwise `y`.
        """
        holds = jnp.less_equal(x, y)
        return jnp.where(holds, jnp.ones_like(x), y)
class ImpliesReichenbach(BinaryConnectiveOperator):
    r"""Reichenbach fuzzy implication operator.

    $\rightarrow_{Reichenbach}(x, y) = 1 - x + xy$

    Attributes:
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.

    Notes:
        - The Reichenbach implication has <b>vanishing gradients</b> for
          $x=0, y=1$.
          This can be prevented by using its [stable](../../stable.md) version,
          which passes $x$ through `not_zeros` and $y$ through `not_ones`.
        - The formula coincides with $x \rightarrow y \equiv \neg x \lor y$
          using [NotStandard][ltnjax.fuzzy_ops.NotStandard] and
          [OrProbSum][ltnjax.fuzzy_ops.OrProbSum]:
          $(1 - x) + y - (1 - x)y = 1 - x + xy$. The code evaluates the
          right-hand side directly.
    """

    def __init__(self, stable: bool = True):
        """Constructor.

        Args:
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        self.stable = stable

    def __call__(
        self, x: ArrayLike, y: ArrayLike, stable: bool | None = None
    ) -> Array:
        """It applies the Reichenbach fuzzy implication operator to the given
        operands.

        Args:
            x: First operand on which the operator has to be applied.
            y: Second operand on which the operator has to be applied.
            stable: (default=None) Per-call override of the `stable`
                attribute; `None` keeps the instance setting.

        Returns:
            The Reichenbach fuzzy implication of the two operands.
        """
        stable = self.stable if stable is None else stable
        if stable:
            # Avoid the corner x=0, y=1 where both partial derivatives vanish.
            x, y = not_zeros(x), not_ones(y)
        return jnp.add(jnp.subtract(1.0, x), jnp.multiply(x, y))
class ImpliesGoguen(BinaryConnectiveOperator):
    r"""Goguen fuzzy implication operator.

    $\rightarrow_{Goguen}(x, y) = \left\{\begin{array}{ c l }1 & \quad \textrm{if } x \le y \\ \frac{y}{x} & \quad \textrm{otherwise} \end{array} \right.$

    Attributes:
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.

    Notes:
        - The quotient $\frac{y}{x}$ is only used where $x > y$. Elsewhere
          (in particular for $x = 0$ with $y \ge 0$) the result is $1$, and
          the division uses a divisor of $1$ there. This keeps both the
          value and the gradient finite.
        - This expression has <b>vanishing gradients</b> if $x \le y$.
          The [stable](../../stable.md) version only keeps $x$ away from
          $0$ (via `not_zeros`); it does not remove this vanishing gradient.
    """  # noqa: E501

    def __init__(self, stable: bool = True):
        """Constructor.

        Args:
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        self.stable = stable

    def __call__(
        self, x: ArrayLike, y: ArrayLike, stable: bool | None = None
    ) -> Array:
        """It applies the Goguen fuzzy implication operator to the given
        operands.

        Args:
            x: First operand on which the operator has to be applied.
            y: Second operand on which the operator has to be applied.
            stable: (default=None) Per-call override of the `stable`
                attribute; `None` keeps the instance setting.

        Returns:
            The Goguen fuzzy implication of the two operands.
        """
        stable = self.stable if stable is None else stable
        if stable:
            x = not_zeros(x)
        x_le_y = jnp.less_equal(x, y)
        # "Double where": replace the divisor by 1 wherever the quotient is
        # discarded. Otherwise x == 0 produces inf/NaN in y / x, and that
        # leaks into the gradient through jnp.where even though the forward
        # value is unaffected.
        safe_x = jnp.where(x_le_y, jnp.ones_like(x), x)
        return jnp.where(x_le_y, jnp.ones_like(x), jnp.divide(y, safe_x))
class ImpliesLuk(BinaryConnectiveOperator):
    r"""Lukasiewicz fuzzy implication operator.

    $\rightarrow_{Lukasiewicz}(x, y) = \min(1 - x + y, 1)$

    Notes:
        - ImpliesLuk has <b>vanishing gradients</b> for $y - x > 0$.
          For $y - x = 0$, both gradients are $0.5$.
    """

    def __call__(self, x: ArrayLike, y: ArrayLike) -> Array:
        """Evaluates the Lukasiewicz implication `x -> y`.

        Args:
            x: Antecedent truth values.
            y: Consequent truth values.

        Returns:
            `min(1 - x + y, 1)`, element-wise.
        """
        unclipped = 1.0 - x + y
        return jnp.minimum(unclipped, 1.0)
class Implies(BinaryConnectiveOperator):
    r"""Implies ($\Rightarrow$) fuzzy operator.

    An Implies operator that uses given negation and disjunction operators.
    This uses the material implication $p \Rightarrow q \equiv \neg p \lor q$.

    Attributes:
        not_op: Fuzzy negation operator to use for the negation operator.
        or_op: Fuzzy disjunction operator to use for the disjunction
            operator.
    """

    def __init__(
        self, not_op: UnaryConnectiveOperator, or_op: BinaryConnectiveOperator
    ):
        """Constructor.

        Args:
            not_op: Fuzzy negation operator to use for the negation operator.
            or_op: Fuzzy disjunction operator to use for the disjunction
                operator.
        """
        self.not_op = not_op
        self.or_op = or_op

    def __call__(self, x: ArrayLike, y: ArrayLike) -> Array:
        """It applies the fuzzy implies operator to the given operands.

        Args:
            x: First operand (antecedent) on which the operator has to be
                applied.
            y: Second operand (consequent) on which the operator has to be
                applied.

        Returns:
            `or_op(not_op(x), y)`, the fuzzy implication of the two operands.
        """
        return self.or_op(self.not_op(x), y)
class Equiv(BinaryConnectiveOperator):
    r"""Equivalence ($\leftrightarrow$) fuzzy operator.

    $x \leftrightarrow y \equiv (x \rightarrow y) \land (y \rightarrow x)$

    Attributes:
        and_op: Fuzzy operator for the conjunction.
        implies_op: Fuzzy operator for the implication.

    Notes:
        - The equivalence is built from two implications joined by a
          conjunction.
        - `and_op` supplies $\land$ and `implies_op` supplies
          $\rightarrow$.
    """  # noqa: E501

    def __init__(
        self,
        and_op: BinaryConnectiveOperator,
        implies_op: BinaryConnectiveOperator,
    ):
        """Stores the conjunction and implication operators to combine.

        Args:
            and_op: Fuzzy operator for the conjunction.
            implies_op: Fuzzy operator for the implication.
        """
        self.and_op = and_op
        self.implies_op = implies_op

    def __call__(self, x: ArrayLike, y: ArrayLike) -> Array:
        """Evaluates the fuzzy equivalence of two operands.

        Args:
            x: Left operand (truth values).
            y: Right operand (truth values).

        Returns:
            `and_op(implies_op(x, y), implies_op(y, x))`.
        """
        forward = self.implies_op(x, y)
        backward = self.implies_op(y, x)
        return self.and_op(forward, backward)
class AggregMin(AggregationOperator):
    r"""Min fuzzy aggregation operator, intended for conjunction.

    $A_{T_{M}}(x_1, \dots, x_n) = \min(x_1, \dots, x_n)$

    Notes:
        - This aggregator has a <b>single-passing gradient</b> for the
          minimum value. If $n$ values attain the minimum, each gets the
          gradient $\frac{1}{n}$.
    """

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
    ) -> Array:
        """Reduces the expression's grounding with the minimum.

        Args:
            xs: Grounding of the expression to aggregate.
            axis: (default=None) Axis or axes to reduce. `None` reduces over
                every axis.
            keepdims: (default=False) Whether reduced axes are kept with
                size one.
            mask: (default=None) Boolean mask selecting which entries of
                'xs' take part (`True` = include, `False` = exclude). Used
                internally for guarded quantification; it must have the same
                shape as 'xs'.

        Returns:
            Min fuzzy aggregation of the formula. The reduction starts from
            `1.0`, so a fully masked slice yields `1.0` and results never
            exceed `1.0`.
        """
        reduce_kwargs = {"axis": axis, "keepdims": keepdims, "where": mask}
        return jnp.min(xs, initial=1.0, **reduce_kwargs)
class AggregMean(AggregationOperator):
    r"""(Weighted) Mean fuzzy aggregation operator.

    [Arithmetic mean](https://en.wikipedia.org/wiki/Arithmetic_mean)
    $A_{M}(x_1, \dots, x_n) = \frac{1}{n} \sum_{i = 1}^n x_i$

    [Weighted arithmetic mean](https://en.wikipedia.org/wiki/Weighted_arithmetic_mean)
    $A_{M}(x_1, \dots, x_n) = \frac{1}{\sum_{i = 1}^n w_i} \sum_{i = 1}^n w_i x_i$
    """  # noqa: E501

    def __call__(
        self,
        xs: ArrayLike,
        weights: ArrayLike | None = None,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
    ) -> Array:
        """Averages the formula's grounding over the selected axes.

        Args:
            xs: Grounding of the expression to aggregate.
            weights: (default=None) Optional weights for a weighted mean.
            axis: (default=None) Axis or axes to reduce. `None` reduces over
                every axis.
            keepdims: (default=False) Whether reduced axes are kept with
                size one.
            mask: (default=None) Boolean mask selecting which entries of
                'xs' take part (`True` = include, `False` = exclude). Used
                internally for guarded quantification; it must have the same
                shape as 'xs'.

        Returns:
            Mean fuzzy aggregation of the formula.
        """
        # TODO rewrite the function in case the
        # [#30678](https://github.com/jax-ml/jax/issues/30678) gets fixed.
        # jnp.mean accepts `where` but not weights, while jnp.average accepts
        # weights but not `where`, hence the case split below.
        if weights is None:
            return jnp.mean(xs, axis=axis, keepdims=keepdims, where=mask)
        if mask is not None:
            # Emulate masking for jnp.average by zeroing both the excluded
            # values and their weights.
            xs = jnp.multiply(xs, mask)
            weights = jnp.multiply(weights, mask)
        return jnp.average(xs, weights=weights, axis=axis, keepdims=keepdims)
class AggregPMean(AggregationOperator):
    r"""[(Weighted) power mean (pmean) / generalized mean aggregation operator](https://en.wikipedia.org/wiki/Generalized_mean),
    see [logictensornetworks](https://github.com/logictensornetworks/logictensornetworks/blob/master/tutorials/2-grounding_connectives.ipynb).

    Generalized mean with $p \neq 0$:
    $\left( \frac{1}{n} \sum_{i = 1}^n x_i^p \right)^{1/p}$

    Generalized mean with $p = 0$:
    $\left( \prod_{i=1}^n x_i \right)^{\frac{1}{n}}$

    Weighted generalized mean with $p \neq 0$:
    $\left( \frac{\sum_{i=1}^n w_i x_i^p}{\sum_{i=1}^n w_i} \right)^{1/p}$

    Weighted generalized mean with $p = 0$:
    $\left( \prod_{i=1}^n x_i^{w_i} \right)^{\frac{1}{\sum_{i=1}^n w_i}}$

    `pMean` can be understood as a smooth-maximum that depends on the
    hyper-parameter $p$:
    - $p \rightarrow -\infty$: the operator tends to $\min$,
    - $p = -1$: harmonic mean,
    - $p = 0$: geometric mean,
    - $p = 1$: mean,
    - $p = 2$: quadratic mean,
    - $p = 3$: cubic mean,
    - $p \rightarrow +\infty$: the operator tends to $\max$.

    Attributes:
        p: (default=2) Value of the parameter p.
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.

    Notes:
        - pMean has <b>exploding gradients</b> for $x_1 = \dotsc = x_n = 0$.
        - If not all values are $0$, the zero-valued ones get a
          <b>vanishing gradient</b> (for $p > 1$).
        - For $p \in \{-1, 0, 2, 3\}$ the computation is delegated to the
          dedicated aggregator, which receives the `stable` flag itself.
          The input is therefore stabilized only once, exactly as in a
          direct call to that aggregator.
    """  # noqa: E501

    def __init__(self, p: int = 2, stable: bool = True):
        """Constructor.

        Args:
            p: (default=2) Value of the parameter p.
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        self.p = p
        self.stable = stable

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        weights: ArrayLike | None = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
        p: int | None = None,
        stable: bool | None = None,
    ) -> Array:
        """It applies the `pMean` aggregation operator to the given formula's
        grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            weights: (default=None) The weights for the aggregation operator.
                If `axis=None`, weights must have the same shape as `xs`.
                If `axis` is given, it must be an `int` or `list[int]`
                containing exactly one `int`, and the weights' length must
                equal `xs.shape[axis]`.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.
            p: (default=None) Per-call override of the `p` attribute.
            stable: (default=None) Per-call override of the `stable`
                attribute; `None` keeps the instance setting.

        Returns:
            `pMean` fuzzy aggregation of the formula.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)
        # TODO rewrite if jax.numpy.mean gets a weights-parameter or
        # if jax.numpy.average gets a where-parameter.

        # Preparing input
        p = self.p if p is None else p
        stable = self.stable if stable is None else stable

        # For specific p-values, delegate to the dedicated aggregators on the
        # raw input. They take `stable` themselves (AggregHMean, for one,
        # applies `not_zeros` when it is set); stabilizing here as well would
        # apply `not_zeros` twice.
        if p == -1:
            return AggregHMean()(
                xs,
                axis=axis,
                weights=weights,
                keepdims=keepdims,
                mask=mask,
                stable=stable,
            )
        if p == 0:
            return AggregGMean()(
                xs,
                axis=axis,
                weights=weights,
                keepdims=keepdims,
                mask=mask,
                stable=stable,
            )
        if p == 2:
            return AggregQMean()(
                xs,
                axis=axis,
                weights=weights,
                keepdims=keepdims,
                mask=mask,
                stable=stable,
            )
        if p == 3:
            return AggregCMean()(
                xs,
                axis=axis,
                weights=weights,
                keepdims=keepdims,
                mask=mask,
                stable=stable,
            )

        if stable:
            xs = not_zeros(xs)

        if p == 1:
            # AggregMean has no `stable` option, so it gets the stabilized
            # input.
            return AggregMean()(
                xs, axis=axis, weights=weights, keepdims=keepdims, mask=mask
            )

        if weights is None:
            return jnp.pow(
                jnp.mean(
                    jnp.pow(xs, p), axis=axis, keepdims=keepdims, where=mask
                ),
                1 / p,
            )
        if mask is None:
            return jnp.pow(
                jnp.average(
                    jnp.pow(xs, p),
                    axis=axis,
                    weights=weights,
                    keepdims=keepdims,
                ),
                1 / p,
            )

        # weights is not None and mask is not None:
        weights = jnp.asarray(weights)
        # Contains inf where x_i == 0 and p < 0 (unless stabilized).
        x_p = jnp.pow(xs, p)

        # Suppress UserWarnings from fuzzy_ops.AggregSum() only for this
        # call. A bare warnings.filterwarnings() would silence UserWarnings
        # process-wide for the rest of the program.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            upper = AggregSum()(
                x_p, axis=axis, weights=weights, keepdims=keepdims, mask=mask
            )
        if len(weights.shape) == 1:
            weights = jnp.broadcast_to(weights, xs.shape)
        lower = jnp.sum(weights, axis=axis, keepdims=keepdims, where=mask)
        return jnp.pow(jnp.divide(upper, lower), 1 / p)
class AggregHMean(AggregationOperator):
    r"""[Harmonic mean fuzzy aggregation operator](https://en.wikipedia.org/wiki/Harmonic_mean).

    The harmonic mean is the special case of the Power Mean with $p=-1$.

    Harmonic mean:
    $\frac{n}{\sum_{i = 1}^n \frac{1}{x_i}}$

    Weighted harmonic mean:
    $\frac{\sum_{i=1}^n w_i}{\sum_{i=1}^n \frac{w_i}{x_i}}$
    $= \left( \frac{\sum_{i=1}^n w_i x_i^{-1}}{\sum_{i=1}^n w_i} \right)^{-1}$

    Attributes:
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.

    Notes:
        - As we divide by $\sum_{i = 1}^n \frac{1}{x_i}$, the values
          $x_i$ <b>must not be $0$</b>. The stable version passes the input
          through `not_zeros` to guarantee this.
    """  # noqa: E501

    def __init__(self, stable: bool = True):
        """Constructor.

        Args:
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        self.stable = stable

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        weights: ArrayLike | None = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
        stable: bool | None = None,
    ) -> Array:
        """It applies the harmonic mean aggregation operator to the given
        expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            weights: (default=None) The weights for the aggregation operator.
                If `axis=None`, weights must have the same shape as `xs`.
                If `axis` is given, it must be an `int` or `list[int]`
                containing exactly one `int`, and the weights' length must
                equal `xs.shape[axis]`.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.
            stable: (default=None) Per-call override of the `stable`
                attribute; `None` keeps the instance setting.

        Returns:
            harmonic mean aggregation applied to the expression.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)
        # Preparing input
        stable = self.stable if stable is None else stable
        if stable:
            xs = not_zeros(xs)

        if weights is None:
            return jnp.reciprocal(
                jnp.mean(
                    jnp.reciprocal(xs),
                    axis=axis,
                    keepdims=keepdims,
                    where=mask,
                )
            )
        if mask is None:
            return jnp.reciprocal(
                jnp.average(
                    jnp.reciprocal(xs),
                    axis=axis,
                    weights=weights,
                    keepdims=keepdims,
                )
            )

        # weights is not None and mask is not None:
        weights = jnp.asarray(weights)
        # Contains inf where x_i == 0 (unless stabilized).
        x_inv = jnp.reciprocal(xs)

        # Suppress UserWarnings from fuzzy_ops.AggregSum() only for this
        # call. A bare warnings.filterwarnings() would silence UserWarnings
        # process-wide for the rest of the program.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            upper = AggregSum()(
                x_inv,
                axis=axis,
                weights=weights,
                keepdims=keepdims,
                mask=mask,
            )
        if len(weights.shape) == 1:
            weights = jnp.broadcast_to(weights, xs.shape)
        lower = jnp.sum(weights, axis=axis, keepdims=keepdims, where=mask)
        return jnp.reciprocal(jnp.divide(upper, lower))
class AggregGMean(AggregationOperator):
    r"""[Geometric Mean fuzzy aggregation operator](https://en.wikipedia.org/wiki/Geometric_mean).
    Intended for conjunction as it approximates the minimum aggregation
    operator.

    Geometric mean:
    $(\prod_{i = 1}^n x_i)^{1/n} = \exp(\sum_{i = 1}^n \ln(x_i) / n)$

    Weighted geometric mean:
    $(\prod_{i = 1}^n x_i^{w_i})^{1/(\sum_{i=1}^n w_i)} = \exp(\sum_{i = 1}^n w_i \ln(x_i) / \sum_{i=1}^n w_i)$

    Attributes:
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.
    """  # noqa: E501

    def __init__(self, stable: bool = True):
        """Constructor.

        Args:
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        self.stable = stable

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        weights: ArrayLike | None = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
        stable: bool | None = None,
    ) -> Array:
        """It applies the geometric mean aggregation operator to the given
        expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            weights: (default=None) The weights for the aggregation operator.
                If `axis=None`, weights must have the same shape as `xs`.
                If there is a shape defined, it must be an `int` or `list[int]`
                containing exactly one `int` that must be the same as
                `xs.shape[axis]`.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.
            stable: (default=None) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
                If None, the value given to the constructor is used.

        Returns:
            Geometric mean aggregation applied to the expression.

        Note:
            `jax.numpy.log` computes the <b>natural logarithm</b>; with
            `stable=False` an included $x_i = 0$ has logarithm $-\infty$.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)
        # Preparing input
        stable = self.stable if stable is None else stable
        if stable:
            xs = not_zeros(xs)

        if weights is None:
            return jnp.exp(
                jnp.mean(jnp.log(xs), axis=axis, keepdims=keepdims, where=mask)
            )
        if mask is None:
            return jnp.exp(
                jnp.average(
                    jnp.log(xs), axis=axis, weights=weights, keepdims=keepdims
                )
            )

        # weights is not None and mask is not None:
        # exp(sum_i w_i ln(x_i) / sum_i w_i), restricted to included entries.
        weights = jnp.asarray(weights)
        x_p = jnp.log(xs)  # -inf where x_i == 0 (avoided by `stable`).

        with warnings.catch_warnings():
            # AggregSum is only an internal building block here. Silence its
            # UserWarning(s) for this call only and leave the process-wide
            # warning filters untouched.
            warnings.simplefilter("ignore", category=UserWarning)
            upper = AggregSum()(
                x_p, axis=axis, weights=weights, keepdims=keepdims, mask=mask
            )
        if weights.ndim == 1:
            # NOTE(review): broadcasting aligns 1-D weights with the *last*
            # axis of `xs`; confirm this matches AggregSum for other axes.
            weights = jnp.broadcast_to(weights, xs.shape)
        lower = jnp.sum(weights, axis=axis, keepdims=keepdims, where=mask)
        return jnp.exp(jnp.divide(upper, lower))
class AggregQMean(AggregationOperator):
    r"""Quadratic mean or root mean square aggregation operator. Intended for
    disjunction as it approximates the maximum operator.

    Quadratic mean:
    $\sqrt{\sum_{i = 1}^n x_i^2 / n}$

    Weighted quadratic mean:
    $\sqrt{ \frac{\sum_{i=1}^n w_i x_i^2}{\sum_{i=1}^n w_i} }$

    Attributes:
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.
    """

    def __init__(self, stable: bool = True):
        """Constructor.

        Args:
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        self.stable = stable

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        weights: ArrayLike | None = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
        stable: bool | None = None,
    ) -> Array:
        """It applies the quadratic mean aggregation operator to the given
        expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            weights: (default=None) The weights for the aggregation operator.
                If `axis=None`, weights must have the same shape as `xs`.
                If there is a shape defined, it must be an `int` or `list[int]`
                containing exactly one `int` that must be the same as
                `xs.shape[axis]`.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.
            stable: (default=None) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
                If None, the value given to the constructor is used.

        Returns:
            Quadratic mean aggregation applied to the expression.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)
        # Preparing input
        stable = self.stable if stable is None else stable
        if stable:
            # Keeps the argument of sqrt away from 0, where its gradient
            # is infinite.
            xs = not_zeros(xs)

        if weights is None:
            return jnp.sqrt(
                jnp.mean(
                    jnp.square(xs), axis=axis, keepdims=keepdims, where=mask
                )
            )
        if mask is None:
            return jnp.sqrt(
                jnp.average(
                    jnp.square(xs),
                    axis=axis,
                    weights=weights,
                    keepdims=keepdims,
                )
            )

        # weights is not None and mask is not None:
        # sqrt(sum_i w_i x_i^2 / sum_i w_i), restricted to included entries.
        weights = jnp.asarray(weights)
        x_p = jnp.square(xs)

        with warnings.catch_warnings():
            # AggregSum is only an internal building block here. Silence its
            # UserWarning(s) for this call only and leave the process-wide
            # warning filters untouched.
            warnings.simplefilter("ignore", category=UserWarning)
            upper = AggregSum()(
                x_p, axis=axis, weights=weights, keepdims=keepdims, mask=mask
            )
        if weights.ndim == 1:
            # NOTE(review): broadcasting aligns 1-D weights with the *last*
            # axis of `xs`; confirm this matches AggregSum for other axes.
            weights = jnp.broadcast_to(weights, xs.shape)
        lower = jnp.sum(weights, axis=axis, keepdims=keepdims, where=mask)
        return jnp.sqrt(jnp.divide(upper, lower))
class AggregCMean(AggregationOperator):
    r"""Cubic mean (power mean with exponent $3$) aggregation operator.
    Intended for disjunction as it approximates the maximum operator.

    Cubic mean:
    $\left( \sum_{i = 1}^n x_i^3 / n \right)^{\frac{1}{3}}$

    Weighted cubic mean:
    $\left( \frac{\sum_{i=1}^n w_i x_i^3}{\sum_{i=1}^n w_i} \right)^{1/3}$

    Attributes:
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.
    """  # noqa: E501

    def __init__(self, stable: bool = True):
        """Constructor.

        Args:
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        self.stable = stable

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        weights: ArrayLike | None = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
        stable: bool | None = None,
    ) -> Array:
        """It applies the cubic mean aggregation operator to the given
        expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            weights: (default=None) The weights for the aggregation operator.
                If `axis=None`, weights must have the same shape as `xs`.
                If there is a shape defined, it must be an `int` or `list[int]`
                containing exactly one `int` that must be the same as
                `xs.shape[axis]`.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.
            stable: (default=None) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
                If None, the value given to the constructor is used.

        Returns:
            Cubic mean aggregation applied to the expression.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)
        # Preparing input
        stable = self.stable if stable is None else stable
        if stable:
            # Keeps the argument of cbrt away from 0, where its gradient
            # is infinite.
            xs = not_zeros(xs)

        if weights is None:
            return jnp.cbrt(
                jnp.mean(
                    jnp.pow(xs, 3), axis=axis, keepdims=keepdims, where=mask
                )
            )
        if mask is None:
            return jnp.cbrt(
                jnp.average(
                    jnp.pow(xs, 3),
                    axis=axis,
                    weights=weights,
                    keepdims=keepdims,
                )
            )

        # weights is not None and mask is not None:
        # cbrt(sum_i w_i x_i^3 / sum_i w_i), restricted to included entries.
        weights = jnp.asarray(weights)
        x_p = jnp.pow(xs, 3)

        with warnings.catch_warnings():
            # AggregSum is only an internal building block here. Silence its
            # UserWarning(s) for this call only and leave the process-wide
            # warning filters untouched.
            warnings.simplefilter("ignore", category=UserWarning)
            upper = AggregSum()(
                x_p, axis=axis, weights=weights, keepdims=keepdims, mask=mask
            )
        if weights.ndim == 1:
            # NOTE(review): broadcasting aligns 1-D weights with the *last*
            # axis of `xs`; confirm this matches AggregSum for other axes.
            weights = jnp.broadcast_to(weights, xs.shape)
        lower = jnp.sum(weights, axis=axis, keepdims=keepdims, where=mask)
        return jnp.cbrt(jnp.divide(upper, lower))
class AggregPMeanError(AggregationOperator):
    r"""`pMeanError` fuzzy aggregation operator.

    $A_{pME}(x_1, \dots, x_n) = 1 - (\frac{1}{n} \sum_{i = 1}^n (1 - x_i)^p)^{\frac{1}{p}}$

    It is computed as $1 - $ [AggregPMean][ltnjax.fuzzy_ops.AggregPMean]
    applied to $1 - x_i$.

    Attributes:
        p: (default=2) Value of the parameter p.
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.

    Notes:
        - pMeanError has <b>exploding gradients</b> for $x_1 = \dotsc = x_n = 1$.
        - If not all values are $1$, the values that are equal to $1$ get a
          <b>vanishing gradient</b> (for $p > 1$).
    """  # noqa: E501

    def __init__(self, p: int = 2, stable: bool = True):
        """Constructor.

        Args:
            p: (default=2) Value of the parameter p.
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        self.p = p
        self.stable = stable

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        weights: ArrayLike | None = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
        p: int | None = None,
        stable: bool | None = None,
    ) -> Array:
        """It applies the `pMeanError` aggregation operator to the given
        formula's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            weights: (default=None) The weights for the aggregation operator.
                If `axis=None`, weights must have the same shape as `xs`.
                If there is a shape defined, it must be an `int` or `list[int]`
                containing exactly one `int` that must be the same as
                `xs.shape[axis]`.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.
            p: (default=None) Value of the parameter p. If None, the value
                given to the constructor is used.
            stable: (default=None) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
                If None, the value given to the constructor is used. It is
                also forwarded to the inner `AggregPMean` call.

        Returns:
            `pMeanError` fuzzy aggregation of the formula.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)
        # Preparing input
        p = self.p if p is None else p
        stable = self.stable if stable is None else stable
        if stable:
            # Keep x_i away from 1 so that 1 - x_i never reaches 0.
            xs = not_ones(xs)

        # pMeanError is the complement of the p-mean of the complements.
        return 1.0 - AggregPMean()(
            1.0 - xs,
            axis=axis,
            weights=weights,
            keepdims=keepdims,
            mask=mask,
            p=p,
            stable=stable,
        )
class AggregMax(AggregationOperator):
    r"""Max fuzzy aggregation operator, meant to be used for disjunction.

    $A_{T_{M}}(x_1, \dots, x_n) = \max(x_1, \dots, x_n)$

    Notes:
        - Only the maximal value receives a gradient (<b>single-passing
          gradient</b>); when $n$ entries share the maximum, each of them
          receives $\frac{1}{n}$.
    """

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
    ) -> Array:
        """Reduce the grounding `xs` with the maximum over `axis`.

        Args:
            xs: Grounding of the expression to aggregate.
            axis: (default=None) Axis (or axes) of the reduction; None
                reduces over every axis.
            keepdims: (default=False) If True, reduced axes are kept with
                size one.
            mask: (default=None) Boolean array shaped like `xs`, used for
                guarded quantification: `True` entries take part in the
                aggregation and `False` entries are ignored.

        Returns:
            The max fuzzy aggregation of `xs`.
        """
        # jnp.max needs an explicit `initial` whenever `where` is given; 0.0
        # is also what an all-excluded slice evaluates to.
        reduce_options = dict(
            axis=axis, keepdims=keepdims, where=mask, initial=0.0
        )
        return jnp.max(jnp.asarray(xs), **reduce_options)
class AggregBoltzmann(AggregationOperator):
    r"""[Boltzmann fuzzy aggregation operator](https://en.wikipedia.org/wiki/Smooth_maximum).
    This is intended to be used as a disjunction operator as it approximates
    the maximum aggregation.

    $S_\alpha(x_1,...,x_n)= \sum_{i = 1}^n x_i * \exp(\alpha x_i) / \sum_{i = 1}^n exp(\alpha x_i)$

    Attributes:
        alpha: $\alpha$ parameter for
            $S_\alpha(x_1,...,x_n)$.

    Notes:
        - $S_\alpha \rightarrow \max$ as $\alpha \rightarrow \infty$.
        - $S_0$ is the arithmetic mean of its inputs.
        - $S_\alpha \rightarrow \min$ as $\alpha \rightarrow -\infty$.
        - The exponentials are computed relative to the largest
          $\alpha x_i$, so large $|\alpha|$ does not overflow.
        - With zero-only mask, the function will return <b>`nan`</b> as we divide
          by a sum, that will be $0$.
    """  # noqa: E501

    def __init__(self, alpha: float):
        r"""Constructor.

        Args:
            alpha: $\alpha$ parameter for
                $S_\alpha(x_1,...,x_n)$.
        """
        self.alpha = alpha

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
    ) -> Array:
        """It applies the Boltzmann aggregation operator to the given
        expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.

        Returns:
            Boltzmann aggregation applied to the expression.
        """
        from jax import lax  # local import: only needed for stop_gradient

        # ArrayLike to Array
        xs = jnp.asarray(xs)

        scaled = jnp.multiply(self.alpha, xs)
        # The softmax weights are invariant to a common shift of the exponent.
        # Subtracting the (masked) maximum of alpha * x_i makes the largest
        # exponent 0 and prevents exp() overflow (inf / inf = nan) for large
        # |alpha|. The shift cancels out, so no gradient flows through it.
        shift = lax.stop_gradient(
            jnp.max(
                scaled,
                axis=axis,
                keepdims=True,
                where=mask,
                initial=-jnp.inf,
            )
        )
        if mask is not None:
            # Excluded entries are pinned to the shift (exp(0) = 1) so they
            # cannot overflow. `where=mask` below still drops them from both
            # sums.
            scaled = jnp.where(mask, scaled, shift)
        exp_expr = jnp.exp(scaled - shift)
        return jnp.divide(
            jnp.sum(
                jnp.multiply(xs, exp_expr),
                axis=axis,
                keepdims=keepdims,
                where=mask,
            ),
            jnp.sum(exp_expr, axis=axis, keepdims=keepdims, where=mask),
        )
class AggregLogSumExp(AggregationOperator):
    r"""[LogSumExp operator](https://en.wikipedia.org/wiki/LogSumExp)
    Intended for disjunction as it approximates the maximum aggregation.

    $LSE_\alpha(x_1,\dotsc,x_n) = (1/\alpha) \log \sum_{i = 1}^n \exp (\alpha x_i)$

    Attributes:
        alpha: $\alpha$ parameter for
            $LSE_\alpha(x_1,\dotsc,x_n)$.

    Notes:
        - Evaluated in the numerically stable form
          $(1/\alpha)(m + \log \sum_{i} \exp(\alpha x_i - m))$ with
          $m = \max_i \alpha x_i$, so large $|\alpha|$ does not overflow.
        - With zero-only mask, the function will return <b>`-inf`</b> (for
          $\alpha > 0$) as we take the logarithm of an empty sum, i.e. of $0$.
    """  # noqa: E501

    def __init__(self, alpha: float):
        r"""Constructor.

        Args:
            alpha: $\alpha$ parameter for
                $LSE_\alpha(x_1,...,x_n)$.
        """
        self.alpha = alpha

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
    ) -> Array:
        """It applies the LogSumExp aggregation operator to the given
        expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.

        Returns:
            LogSumExp aggregation applied to the expression.
        """
        from jax import lax  # local import: only needed for stop_gradient

        # ArrayLike to Array
        xs = jnp.asarray(xs)

        scaled = jnp.multiply(self.alpha, xs)
        # log sum exp(a_i) = m + log sum exp(a_i - m) for any m. Choosing the
        # (masked) maximum keeps every exponent <= 0, so exp() cannot overflow.
        shift = lax.stop_gradient(
            jnp.max(
                scaled,
                axis=axis,
                keepdims=True,
                where=mask,
                initial=-jnp.inf,
            )
        )
        if mask is not None:
            # Pin excluded entries to the shift so they cannot overflow;
            # `where=mask` still drops them from the sum.
            scaled = jnp.where(mask, scaled, shift)
        summed = jnp.sum(
            jnp.exp(scaled - shift), axis=axis, keepdims=keepdims, where=mask
        )
        # `shift` has the keepdims=True shape with as many elements as the
        # output, so a reshape aligns it for either `keepdims` value.
        log_sum_exp = jnp.reshape(shift, summed.shape) + jnp.log(summed)
        return jnp.multiply(jnp.divide(1, self.alpha), log_sum_exp)
class AggregMellowmax(AggregationOperator):
    r"""[Mellowmax fuzzy aggregation operator](https://en.wikipedia.org/wiki/Smooth_maximum#Mellowmax).
    This is intended to be used as a disjunction operator as it approximates
    the maximum aggregation.

    $mm_\alpha(x_1,...,x_n) = (1/\alpha) \log \frac{1}{n} \sum_{i = 1}^n \exp (\alpha x_i)$

    Attributes:
        alpha: $\alpha$ parameter for
            $mm_\alpha(x_1,...,x_n)$.

    Notes:
        - The result is <b>undefined</b> if we set $\alpha=0$.
        - $mm_\alpha \rightarrow \max$ as $\alpha \rightarrow \infty$.
        - $mm_\alpha$ tends to the arithmetic mean of its inputs as
          $\alpha \rightarrow 0$.
        - $mm_\alpha \rightarrow \min$ as $\alpha \rightarrow -\infty$.
        - $n$ counts only the entries included by `mask`.
        - The exponentials are computed relative to the largest
          $\alpha x_i$, so large $|\alpha|$ does not overflow.
        - With zero-only mask, the function will return <b>`nan`</b> as we
          divide by $n=0$.
    """  # noqa: E501

    def __init__(self, alpha: float):
        r"""Constructor.

        Args:
            alpha: $\alpha$ parameter for
                $mm_\alpha(x_1,\dotsc,x_n)$.
        """
        self.alpha = alpha

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
    ) -> Array:
        """It applies the Mellowmax aggregation operator to the given
        expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.

        Returns:
            Mellowmax aggregation applied to the expression.
        """
        from jax import lax  # local import: only needed for stop_gradient

        # ArrayLike to Array
        xs = jnp.asarray(xs)

        scaled = jnp.multiply(self.alpha, xs)
        # log sum exp(a_i) = m + log sum exp(a_i - m); using the (masked)
        # maximum as m keeps every exponent <= 0 so exp() cannot overflow.
        shift = lax.stop_gradient(
            jnp.max(
                scaled,
                axis=axis,
                keepdims=True,
                where=mask,
                initial=-jnp.inf,
            )
        )
        if mask is not None:
            # Pin excluded entries to the shift so they cannot overflow;
            # `where=mask` still drops them from the sums.
            scaled = jnp.where(mask, scaled, shift)
        summed = jnp.sum(
            jnp.exp(scaled - shift), axis=axis, keepdims=keepdims, where=mask
        )
        # Number of aggregated entries, reduced exactly like `summed`. This
        # keeps the shapes equal for every axis/keepdims/mask combination; a
        # count computed without `keepdims` would broadcast wrongly.
        n = jnp.sum(
            jnp.ones_like(scaled), axis=axis, keepdims=keepdims, where=mask
        )
        # `shift` has the keepdims=True shape with as many elements as the
        # output, so a reshape aligns it for either `keepdims` value.
        log_mean_exp = (
            jnp.reshape(shift, summed.shape) + jnp.log(summed) - jnp.log(n)
        )
        # jnp.divide (not jnp.reciprocal) so an integer alpha is not
        # truncated by integer reciprocal.
        return jnp.multiply(jnp.divide(1.0, self.alpha), log_mean_exp)
class AggregProd(AggregationOperator):
    r"""(Weighted) Product fuzzy aggregation operator.
    Intended for conjunction.

    Product:
    $\prod_{i = 1}^n x_i$

    Weighted Product:
    $\prod_{i = 1}^n x_i^{w_i}$
    $= \exp( \sum_{i = 1}^n w_i \ln x_i )$

    Attributes:
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.

    Notes:
        - prod has <b>vanishing gradients</b> if at least two values of $x_i$
          are $0$.
          If exactly one value is $0$, we have a <b>single-passing gradient</b>
          for that value.
    """

    def __init__(self, stable: bool = True):
        """Constructor.

        Args:
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        self.stable = stable

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        weights: ArrayLike | None = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
        stable: bool | None = None,
    ) -> Array:
        """It applies the prod fuzzy aggregation operator to the given
        expression's grounding on the selected dimensions.

        The product is evaluated in log-space as
        `exp(AggregSum(log(xs), weights=weights, mask=mask))`.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            weights: (default=None) The weights for the aggregation operator.
                If `axis=None`, weights must have the same shape as `xs`.
                If there is a shape defined, it must be an `int` or `list[int]`
                containing exactly one `int` that must be the same as
                `xs.shape[axis]`.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.
            stable: (default=None) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
                If None, the value given to the constructor is used.

        Returns:
            Prod fuzzy aggregation applied to the expression.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)

        # Prepare input
        stable = self.stable if stable is None else stable
        if stable:
            # Avoids log(0) = -inf below.
            xs = not_zeros(xs)

        with warnings.catch_warnings():
            # AggregSum is only an internal building block here. Silence its
            # UserWarning(s) for this call only and leave the process-wide
            # warning filters untouched.
            warnings.simplefilter("ignore", category=UserWarning)
            log_sum = AggregSum()(
                xs=jnp.log(xs),
                axis=axis,
                weights=weights,
                keepdims=keepdims,
                mask=mask,
            )
        return jnp.exp(log_sum)
class AggregProbSum(AggregationOperator):
    r"""Probabilistic sum aggregation operator, meant for disjunction.

    It is the dual of [AggregProd][ltnjax.fuzzy_ops.AggregProd] under the
    standard negation $1 - x$:

    $1-\prod_{i = 1}^n (1-x_i)$

    Attributes:
        stable: (default=True) Default for whether the
            [stable](../../stable.md) variant of the operator is used.

    Notes:
        - ProbSum has <b>vanishing gradients</b> as soon as two or more
          values $x_i$ equal $1$.
          When exactly one $x_i$ equals $1$, only that value receives a
          gradient (<b>single-passing gradient</b>).
    """

    def __init__(self, stable: bool = True):
        """Create the operator and remember its default stabilization mode.

        Args:
            stable: (default=True) Default for whether the
                [stable](../../stable.md) variant of the operator is used.
        """
        self.stable = stable

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        weights: ArrayLike | None = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
        stable: bool | None = None,
    ) -> Array:
        """Aggregate the grounding `xs` with the probabilistic sum.

        Args:
            xs: Grounding of the expression to aggregate.
            axis: (default=None) Axis of the reduction; None reduces over
                every axis.
            weights: (default=None) Aggregation weights. With `axis=None`
                they must be shaped like `xs`; otherwise `axis` must name a
                single axis (an `int` or a one-element `list[int]`) whose
                length `xs.shape[axis]` the weights must match.
            keepdims: (default=False) If True, reduced axes are kept with
                size one.
            mask: (default=None) Boolean array shaped like `xs`, used for
                guarded quantification: `True` entries are aggregated and
                `False` entries are ignored.
            stable: (default=None) Overrides the constructor's choice of the
                [stable](../../stable.md) variant when not None; the
                resolved value is also handed to the inner product.

        Returns:
            The probabilistic-sum aggregation of `xs`.
        """
        values = jnp.asarray(xs)

        use_stable = stable if stable is not None else self.stable
        if use_stable:
            # Stay below 1 so that the complements below never reach 0.
            values = not_ones(values)

        complements = 1.0 - values
        product_of_complements = AggregProd()(
            complements,
            axis=axis,
            weights=weights,
            keepdims=keepdims,
            mask=mask,
            stable=use_stable,
        )
        return 1.0 - product_of_complements
class AggregLogProd(AggregationOperator):
    r"""Log-product aggregation operator.

    $\sum_{i = 1}^n \log x_i$

    Attributes:
        stable: (default=True) Flag indicating whether to use the
            [stable](../../stable.md) version of the operator or not.

    Notes:
        - The values in `xs` should <b>not be $0$</b> as we take its
          logarithm.
        - Entries excluded by `mask` contribute $\log 1 = 0$, i.e. they are
          left out of the product.
    """

    def __init__(self, stable: bool = True):
        """Constructor.

        Args:
            stable: (default=True) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
        """
        warn(
            "`Aggreg_LogProd` outputs values out of the truth value range "
            "[0,1]. "
            "Its usage with other connectives could be compromised. "
            "Use it carefully.",
            UserWarning,
            # Attribute the warning to the code constructing the operator.
            stacklevel=2,
        )
        self.stable = stable

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
        stable: bool | None = None,
    ) -> Array:
        """It applies the LogProd aggregation operator to the given
        expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.
            stable: (default=None) Flag indicating whether to use the
                [stable](../../stable.md) version of the operator or not.
                If None, the value given to the constructor is used.

        Returns:
            LogProd aggregation applied to the expression.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)

        stable = self.stable if stable is None else stable
        if stable:
            xs = not_zeros(xs)
        if mask is not None:
            # Excluded entries become 1 so they contribute log(1) = 0. Doing
            # this before the log also keeps their gradient finite.
            xs = jnp.where(mask, xs, 1.0)
        return jnp.sum(jnp.log(xs), axis=axis, keepdims=keepdims)
# Alternative name for AggregLogProd: the sum of the logarithms equals the
# logarithm of the product. Both names refer to the same class object.
AggregSumLog = AggregLogProd
class AggregLukMax(AggregationOperator):
    r"""Lukasiewicz fuzzy maximum aggregation operator.
    This is intended to be used as a conjunction operator.

    $\max(\sum_{i = 1}^n x_i - n + 1, 0)$

    Notes:
        - Luk_Max has <b>vanishing gradients</b> for
          $\sum_{i = 1}^n x_i - n + 1 < 0$.
          If $\sum_{i = 1}^n x_i - n + 1 = 0$, all gradients are $0.5$.
        - $n$ counts only the entries included by `mask`; with a zero-only
          mask the result is $1$ (empty sum and $n = 0$).
    """

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
    ) -> Array:
        """It applies the Lukasiewicz fuzzy conjunction aggregation operator
        to the given expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.

        Returns:
            Lukasiewicz fuzzy conjunction aggregation applied to the
            expression.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)

        # Count the aggregated entries with the same reduction as the sum
        # below, so `n` has the output's shape for every axis/keepdims/mask
        # combination. A count computed without `keepdims` would broadcast
        # wrongly when keepdims=True.
        n = jnp.sum(
            jnp.ones_like(xs), axis=axis, keepdims=keepdims, where=mask
        )
        return jnp.maximum(
            jnp.sum(xs, axis=axis, keepdims=keepdims, where=mask) - n + 1, 0
        )
class AggregLukMin(AggregationOperator):
    r"""Lukasiewicz fuzzy minimum aggregation operator. Intended for
    disjunction.

    $\min(\sum_{i = 1}^n x_i, 1)$

    Notes:
        - We have <b>vanishing gradients</b> for the case
          $\sum_{i = 1}^n x_i > 1$.
          For the case that $\sum_{i = 1}^n x_i = 1$, all gradients are
          $0.5$.
        - With a zero-only mask the (empty) sum is $0$, so the result is $0$.
    """

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
    ) -> Array:
        """It applies the Lukasiewicz fuzzy disjunction aggregation operator
        to the given expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation to be
                computed. If None, the aggregation is computed along all the
                axes.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape of 'xs'.
                `False` means exclusion, `True` means inclusion.

        Returns:
            Lukasiewicz fuzzy disjunction aggregation applied to the
            expression.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)

        # Bounded sum: clip the (masked) sum of truth values at 1.
        return jnp.minimum(
            1.0, jnp.sum(xs, axis=axis, keepdims=keepdims, where=mask)
        )
class AggregYager2(AggregationOperator):
    r"""Yager2 fuzzy aggregation operator. Intended for disjunction, as it
    approximates [AggregLukMin][ltnjax.fuzzy_ops.AggregLukMin].

    $\min(1, \sqrt{\sum_{i = 1}^n x_i^2})$

    Notes:
        - We have <b>vanishing gradients</b> for the case
          $\sqrt{\sum_{i = 1}^n x_i^2} > 1$.
          For the case that $\sqrt{\sum_{i = 1}^n x_i^2} = 1$, the gradient
          with respect to $x_i$ is $0.5 \cdot x_i$.
        - If all (non-masked) $x_i$ are $0$, the gradient is $0$ instead of
          `NaN`.
    """

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
    ) -> Array:
        """It applies the Yager2 fuzzy aggregation operator to the given
        expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation is to be
                computed. If None, the aggregation is computed along all the
                axes.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape as 'xs'.
                `False` means exclusion, `True` means inclusion.

        Returns:
            Yager2 fuzzy aggregation applied to the expression.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)

        sq_sum = jnp.sum(
            jnp.square(xs), axis=axis, keepdims=keepdims, where=mask
        )
        # Double-`where` trick: the derivative of sqrt is infinite at 0, so a
        # plain `jnp.sqrt(sq_sum)` produces NaN gradients when every entry is
        # zero (or masked out). Feeding sqrt a dummy value there keeps the
        # gradient finite (0) while the forward value stays sqrt(0) == 0.
        # `== 0` (rather than `> 0`) lets NaN inputs still propagate.
        is_zero = sq_sum == 0
        norm = jnp.where(
            is_zero, 0.0, jnp.sqrt(jnp.where(is_zero, 1.0, sq_sum))
        )
        return jnp.minimum(1.0, norm)
2104"""Special aggregators"""
class AggregInverted(AggregationOperator):
    r"""Inversion operator that turns conjunction aggregation operators into
    disjunction aggregation operators and vice versa. The given `aggreg_op`
    is inverted.

    For that, we make use of De Morgan's law
    $\bigvee(x_1, x_2, x_3) = \neg(\bigwedge(\neg(x_1), \neg(x_2), \neg(x_3)))$
    with the standard negation $\neg(x) = 1 - x$.

    For example, inverting the p-mean error aggregator with $p = 2$ yields
    the p-mean aggregator with $p = 2$.

    Attributes:
        aggreg_op: The aggregation operator that is inverted.
    """  # noqa: E501

    def __init__(self, aggreg_op: AggregationOperator):
        """Constructor.

        Args:
            aggreg_op: The aggregation operator that is inverted.
        """
        self.aggreg_op = aggreg_op

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
        **kwargs: Any,
    ) -> Array:
        """It applies the inverted aggregation operator `aggreg_op` to the
        given expression's grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation is to be
                computed. If None, the aggregation is computed along all the
                axes.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape as 'xs'.
                `False` means exclusion, `True` means inclusion.
            kwargs: Further keyword arguments forwarded unchanged to
                `aggreg_op` (e.g. `weights` for a weighted sum aggregator).

        Returns:
            Inverted aggregation applied to the expression.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)

        # neg(aggreg_op(neg(xs))) with the standard negation 1 - x.
        return 1.0 - self.aggreg_op(
            1.0 - xs, axis=axis, keepdims=keepdims, mask=mask, **kwargs
        )
class AggregSum(AggregationOperator):
    r"""(Weighted) Sum aggregation operator.

    $\sum_{i = 1}^n w_i x_i$

    Without weights, all $w_i$ are $1$, i.e. the plain sum is computed.
    The output is not restricted to the truth value range $[0, 1]$.
    """

    def __init__(self):
        """Constructor.

        Emits a `UserWarning`, because the output of this operator may lie
        outside of the truth value range $[0, 1]$.
        """
        # stacklevel=2 attributes the warning to the code constructing the
        # operator instead of to this line inside the library.
        warn(
            "`Aggreg_Sum` outputs values out of the truth value range [0,1]. "
            "Its usage with other connectives could be compromised. "
            "Use it carefully.",
            UserWarning,
            stacklevel=2,
        )

    def __call__(
        self,
        xs: ArrayLike,
        axis: Axis = None,
        weights: ArrayLike | None = None,
        keepdims: bool = False,
        mask: ArrayLike | None = None,
    ) -> Array:
        """It applies the (weighted) sum operator to the given expression's
        grounding on the selected dimensions.

        Args:
            xs: Grounding of expression on which the aggregation has to be
                performed.
            axis: (default=None) Axis along which the aggregation is to be
                computed. If None, the aggregation is computed along all the
                axes. Without `weights`, an `int` or a sequence of `int`s is
                accepted.
            weights: (default=None) The weights for the aggregation operator.
                If `axis=None`, weights must have the same shape as `xs`.
                If `axis` is given, it must be an `int` or `list[int]`
                containing exactly one `int`, and `weights` must be a 1D-array
                of length `xs.shape[axis]`.
            keepdims: (default=False) Flag indicating whether the output has to
                keep the same dimensions as the input after the aggregation.
            mask: (default=None) Boolean mask for excluding values of 'xs'
                from the aggregation. It is internally used for guarded
                quantification. The mask must have the same shape as 'xs'.
                `False` means exclusion, `True` means inclusion.

        Returns:
            (Weighted) sum aggregator applied to the expression.

        Raises:
            ValueError: If `weights` are given but not compatible with `xs`,
                or if `weights` are given together with more than one axis.
        """
        # ArrayLike to Array
        xs = jnp.asarray(xs)

        # TODO Rewrite if jax.numpy get a weights-parameter.
        if weights is None:
            # `jnp.sum` reduces over several axes at once, so every given
            # axis is forwarded (previously only the first one was used).
            if isinstance(axis, Sequence):
                axis = tuple(axis)
            return jnp.sum(xs, axis=axis, keepdims=keepdims, where=mask)

        weights = jnp.asarray(weights)

        # The weighted variants reduce along at most one axis, which may be
        # given as an int or as a sequence containing exactly one int.
        if isinstance(axis, Sequence):
            if len(axis) != 1:
                raise ValueError(
                    "For weighted aggregation, `axis` must be None, an int "
                    f"or a sequence containing exactly one int, got {axis!r}."
                )
            axis = axis[0]

        if mask is not None:
            # Excluded entries contribute 0 to the weighted sum.
            xs = jnp.multiply(xs, mask)

        # Element-wise variant: one weight per entry of `xs`.
        if axis is None and xs.shape == weights.shape:
            out = jnp.dot(xs.ravel(), weights.ravel())
            return jnp.reshape(out, (1,) * xs.ndim) if keepdims else out

        # Aggregate along `axis` with 1-dim weights of length xs.shape[axis].
        if (
            axis is not None
            and weights.ndim == 1
            and weights.shape[0] == xs.shape[axis]
        ):
            out = jnp.tensordot(xs, weights, axes=[[axis], [0]])
            return jnp.expand_dims(out, axis) if keepdims else out

        # None of the above cases.
        raise ValueError(
            "weights must either be None, or the same shape "
            "as `xs` or an 1D-array with length "
            "`xs.shape[axis]`."
        )