Coverage for src/ltnjax/core.py: 90%

244 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-26 11:35 +0000

1# SPDX-FileCopyrightText: 2026 German Aerospace Center (DLR) 

2# SPDX-License-Identifier: MIT 

3# 

4from __future__ import annotations 

5 

6from collections.abc import Callable, Sequence 

7import types 

8from typing import Any 

9 

10from flax import nnx 

11 

12# Annotations 

13from jax import Array # should be used for outputs 

14import jax.numpy as jnp 

15from jax.typing import ArrayLike # should be used for inputs 

16import numpy as np 

17 

18import ltnjax as ltn 

19 

20 

# Type alias for the label (name) of an LTN variable, e.g. 'x'. Internal
# labels are also built from it, such as the "diag_..." labels created by
# `diag` and the "_flat_..." labels created by `_flatten_free_dims`.
VarLabel = str

22 

23 

class LTNObject(nnx.Module):
    r"""Class representing a generic LTN object.

    An LTN object contains the results of an expression with variables
    `free_vars`.

    In LTNjax, LTN objects are constants, variables, and outputs of predicates,
    formulas, functions, connectives, and quantifiers.

    Internally this contains a tensor `value` and an ordered list `free_vars`.
    If the expression has $n$ `free_vars` $x_1,\dotsc,x_n$, then
    the first $n$ axes belong to these variables. The remaining axes hold
    the `feature_dims` of the result; there can be $[0,\infty)$ many of them.
    For each setting of the `free_vars`, the tensor contains the
    corresponding result. The example below makes this clearer.

    The class extends `nnx.Module` since it may contain (trainable) values.

    Attributes:
        value: The result of the expression. If there are $n$
            `free_vars`, then the first $n$ axes belong to these
            `free_vars` in the same order.
        free_vars: The free variables that are contained in the expression.
            `free_vars` is ordered in a way that if we have $n$
            free_vars, the first $n$ axes of `value` belong to these
            variables.

    Examples:
        We define the variables $x$ with individuals `[1., 0.]` and $y$ with
        individuals `[0., 1.]`, and the expression $x * y$.
        ```python
        >>> import ltnjax as ltn
        >>> x = ltn.Variable('x', [1., 0.])
        >>> y = ltn.Variable('y', [0., 1.])
        >>> prod = ltn.Connective(ltn.fuzzy_ops.AndProd(stable=False))
        >>> res = prod(x, y)
        ```

        If we choose x to be 1. (the value at index 0) and y to be 1. (the
        value at index 1), we get $1 * 1 = 1$.
        ```python
        >>> print(res.take('x', 0).take('y', 1))
        LTNObject(value=Array(1., dtype=float32), free_vars=[])
        ```

        If we only choose y to be 1 (the value at index 1), we get the
        LTN object $x * 1$. The result still depends on which x we chose.
        Either $x=0$ and we get $x * 1=0$ or $x=1$ and we
        get $x * 1=1$.
        ```python
        >>> print(res.take('y', 1))
        LTNObject(value=Array(shape=(2,), dtype=dtype('float32')),\
        free_vars=['x'])
        >>> print(res.take('y', 1).value)
        [1. 0.]
        ```

        If we choose no variable, then the LTN object will contain all
        possible results:
        ```python
        >>> print(res)
        LTNObject(value=Array(shape=(2, 2), dtype=dtype('float32')),\
        free_vars=['x', 'y'])
        >>> print(res.value)
        [[0. 1.]
         [0. 0.]]
        ```

        Note that LTN objects can also be tensors of any rank:
        ```python
        >>> x = ltn.Variable('x', [[[3.1415926535, 2.7182818284],\
        [1.414213562, .6931471805]], [[1., 0.], [0., 1.]]])
        >>> print(x)
        Variable(value=Array(shape=(2, 2, 2), dtype=dtype('float32')),\
        free_vars=['x'], var_label='x')
        >>> print(x.value)
        [[[3.1415927 2.7182817]
          [1.4142135 0.6931472]]

         [[1.        0.       ]
          [0.        1.       ]]]
        ```
        Here, the first axis belongs to free variable 'x' while the remaining
        axes are feature dimensions.
    """

    def __init__(
        self,
        value: ArrayLike,
        free_vars: list[VarLabel],
        trainable: bool = False,
    ) -> None:
        """Constructor.

        Args:
            value: The result of the expression. If there are $n$
                `free_vars`, then the first $n$ axes belong to these
                `free_vars` in the same order. Converted to a float32 array.
            free_vars: List of labels of free variables that are contained in
                the expression. `free_vars` is ordered in a way that if we
                have $n$ free_vars, the first $n$ axes of `value`
                belong to these variables.
            trainable: Flag indicating whether the value is trainable
                (embedding, wrapped in `nnx.Param`) or not.
        """
        if not trainable:
            self.value = jnp.asarray(value, dtype=jnp.float32)
        else:
            # MyPy ignore: nnx.Param is used like a jnp array throughout.
            self.value = nnx.Param(jnp.asarray(value, dtype=jnp.float32))  # type: ignore
        self.free_vars: list[VarLabel] = free_vars

    def __repr__(self) -> str:
        """Representation function.

        Called by the repr() built-in function to compute the "official"
        string representation of an object.
        """
        return (
            f"ltn.{self.__class__.__name__}(value={self.value}, "
            f"free_vars={self.free_vars})"
        )

    def _copy(self) -> LTNObject:
        """Copy function.

        Creates a new LTN object that refers to the same `value` (for
        gradient tracking) and the same `free_vars` list.

        Returns:
            Copy of this LTN object instance.
        """
        return LTNObject(self.value, self.free_vars)

    def _get_axis_of_free_var(self, free_var: VarLabel) -> int:
        """Axis of `free_var`.

        Given a free variable `free_var`, returns the axis in attribute
        `value` that belongs to this free variable.

        Args:
            free_var: Label of the free variable, whose axis we want.

        Returns:
            Axis that belongs to the free variable `free_var`.

        Raises:
            ValueError: If `free_var` does not occur in `free_vars`.
        """
        if free_var not in self.free_vars:
            raise ValueError(
                f"{free_var} is not a free variable occurring in the LTN "
                "object."
            )
        return self.free_vars.index(free_var)

    def _get_dim_of_free_var(self, free_var: VarLabel) -> int:
        """Dimension of `free_var`.

        Given a label `free_var`, returns the corresponding dimension.

        Args:
            free_var: The label of the variable.

        Returns:
            The dimension (axis length) that corresponds to `free_var`.

        Raises:
            ValueError: If `free_var` does not occur in `free_vars`.

        Note:
            Do not confuse `axis` with `dimension`. The first is the position
            in the shape of `value` whereas the second is the entry there.
        """
        return jnp.shape(self.value)[self._get_axis_of_free_var(free_var)]

    def shape(self) -> tuple[int, ...]:
        """Returns the shape of the grounding of the LTN object.

        Returns:
            The shape of the grounding of the LTN object.
        """
        return self.value.shape

    def take(self, free_var: VarLabel, indices: int | list[int]) -> LTNObject:
        r"""Take elements along the axis that corresponds to `free_var`.

        With a single integer index the axis of `free_var` is removed and the
        variable is no longer free. With a list of indices the axis is kept
        (now of length `len(indices)`), so `free_var` stays free at the same
        position.

        Args:
            free_var: The variable we want to fix.
            indices: This is either an integer or a list of integers from
                $\{0, \dotsc, n-1\}$ where $n$ is the dimension
                of variable `free_var`. This parameter has the same effect as
                `jnp.array(indices)` in `jax.numpy.take`.

        Returns:
            The LTN object containing the elements along the axis that
            corresponds to `free_var`.

        Raises:
            ValueError: If `indices` is neither int nor list of ints, or if
                `free_var` is not a free variable of this object.
        """
        indices_arr: Array = jnp.array(indices)
        if jnp.ndim(indices_arr) not in (0, 1):
            raise ValueError("Give a single indice or a list of indices.")
        axis = self._get_axis_of_free_var(free_var)

        result = self._copy()
        result.value = jnp.take(self.value, indices_arr, axis=axis)
        if jnp.ndim(indices_arr) == 0:
            # Scalar index: jnp.take removes the axis, so the variable is fixed.
            result.free_vars = [v for v in self.free_vars if v != free_var]
        else:
            # List of indices: jnp.take keeps the axis, so the variable must
            # stay in `free_vars` to keep axes and labels aligned.
            result.free_vars = list(self.free_vars)
        return result

234 

235 

class Variable(LTNObject):
    r"""Class representing an LTN variable.

    A variable $x$ that can take only a <b>finite number</b>
    $n$ of tensors of any rank, i.e. with or without feature
    dimensions.

    Without feature dimensions:
    $x \in \mathcal{R}^n$.

    With feature dimensions $d_1 \times \dotsc \times d_m$:
    $x \in \mathcal{R}^{n \times d_1 \times \dotsc \times d_m}$.

    Attributes:
        label: The name of the variable (given as `var_label`).
        value: The array describes a batch of individuals;
            The first axis describes the number of individuals and
            the optional remaining axes are the `feature_dims` of the
            individuals. The `feature_dims` of all individuals must be equal.
        free_vars: The free variables that are contained in the expression.
            For a freshly created variable this is `[label]`; `diag` and
            `undiag` may change it.

    Raises:
        ValueError: If `var_label` starts with one of the reserved strings
            `diag` or `_flat`.

    Note:
        - The first dimension $n$ of an LTN variable is associated with
          the number of individuals in the variable, while the other $m$
          axes are associated with the features of the individuals;
        - If the Variable should contain truth values in `[0., 1.]` that are
          updated during training, ensure that they will stay in that interval.
          This could be done by using sigmoid or tanh operations on the values.
          Avoid using `jax.numpy.clip(x, 0., 1.)` during training as this will
          lead to gradient issues.
    """

    def __init__(
        self, var_label: VarLabel, individuals: Any, trainable: bool = False
    ) -> None:
        """Constructor.

        Args:
            var_label: The name of the variable.
            individuals: An object that is convertible to an array. This
                includes JAX arrays, NumPy arrays, Python scalars, Python
                collections like lists and tuples, objects with an
                ``__array__`` method, and objects supporting the Python buffer
                protocol.

                The array describes a batch of individuals;
                The first axis describes the number of individuals and
                the optional remaining axes are the `feature_dims` of the
                individuals. The `feature_dims` of all individuals must be
                equal. A scalar is treated as a single individual.
            trainable: Flag indicating whether the LTN variable is trainable
                (embedding) or not.

        Raises:
            ValueError: If `var_label` starts with `diag` or `_flat`, which
                are reserved for internally generated labels.
        """
        for reserved in ["diag", "_flat"]:
            if var_label.startswith(reserved):
                raise ValueError(
                    f"Labels starting with {reserved} are reserved."
                )

        # The input could be a scalar, list, numpy array, etc.
        value = jnp.asarray(individuals, dtype=jnp.float32)

        # A 0-D input becomes a batch containing one individual.
        if jnp.ndim(value) == 0:
            value = value.reshape(1)
        super().__init__(value, free_vars=[var_label], trainable=trainable)
        self.label = var_label

    def __repr__(self) -> str:
        """Representation function.

        Called by the repr() built-in function to compute the "official"
        string representation of an object.
        """
        return (
            f"ltn.{self.__class__.__name__}(var_label={self.label}, "
            f"value={self.value}, free_vars={self.free_vars})"
        )

325 

326 

class Constant(LTNObject):
    """The class representing constants.

    A constant is a tensor of any rank without free variables: every axis of
    its `value` is a feature dimension.

    Attributes:
        value: Value of the constant, a float32 array of arbitrary rank
            (wrapped in `nnx.Param` when trainable).
        free_vars: Always the empty list, since a constant does not depend
            on any variable.
    """

    def __init__(self, value: Any, trainable: bool = False) -> None:
        """Constructor.

        Args:
            value: An object that is convertible to an array. This includes JAX
                arrays, NumPy arrays, Python scalars, Python collections like
                lists and tuples, objects with an ``__array__`` method, and
                objects supporting the Python buffer protocol.
            trainable: Flag indicating whether the LTN constant is trainable
                (embedding) or not.
        """
        # Normalise scalars, lists, NumPy arrays, ... to a float32 JAX array.
        as_array = jnp.asarray(value, dtype=jnp.float32)
        no_free_vars: list[VarLabel] = []
        super().__init__(as_array, free_vars=no_free_vars, trainable=trainable)

    def __repr__(self) -> str:
        """Representation function.

        Returns the "official" string form used by the repr() built-in.
        """
        name = type(self).__name__
        return f"ltn.{name}(value={self.value}, free_vars={self.free_vars})"

370 

371 

372def _flatten_free_dims( 

373 exprs: list[LTNObject], in_place: bool = False 

374) -> list[LTNObject]: 

375 r"""Flattens `free_dims` of `exprs`. 

376 

377 Flattens the $[0,\infty)$ many free dimensions in `exprs` into 

378 one `_flat_` axis. `free_dims` and the `feature_dims` may be empty. If 

379 `free_dims` is empty, the `_flat_`-axis will be of size $1$. That is, 

380 for each LTN object in `exprs` if `expr.value` has shape (5,10,2,3) and 

381 `expr.free_vars=['x','y']`, we will get an output of shape (50,2,3) with 

382 `expr.free_vars = ['_flat_x_y']`. 

383 

384 Args: 

385 exprs: List of [LTNObject][ltnjax.core.LTNObject]. 

386 in_place: (default=False) Boolean that decides whether we perform the 

387 operation on a new copy or on the same LTN objects. 

388 

389 Returns: 

390 List of the flattened LTN objects. The output shape will be the 

391 `_flat_`-axis plus the $[0,\infty)$ many feature_dims. 

392 

393 Note: 

394 - Do not confuse `free_dims` with `feature_dims`. The `free_dims` are 

395 these that are associated with `free_vars`. 

396 """ 

397 if not in_place: 397 ↛ 399line 397 didn't jump to line 399 because the condition on line 397 was always true

398 exprs = [expr._copy() for expr in exprs] 

399 for expr in exprs: 

400 non_var_shape = expr.value.shape[len(expr.free_vars) :] 

401 if expr.value.size == 0: 

402 # The case that we have an empty axis needs extra treatment 

403 # because: https://github.com/numpy/numpy/issues/18519 . 

404 # We need one dimension for the flattened free_vars and the 

405 # non_var_shape does not change. 

406 expr.value = jnp.reshape( 

407 expr.value, 

408 shape=tuple( 

409 [np.prod(expr.value.shape[: len(expr.free_vars)])] 

410 + list(non_var_shape) 

411 ), 

412 ) 

413 else: 

414 expr.value = jnp.reshape( 

415 expr.value, shape=tuple([-1] + list(non_var_shape)) 

416 ) 

417 expr.free_vars = ["_flat_" + "_".join(expr.free_vars)] 

418 return exprs 

419 

420 

class Predicate(nnx.Module):
    """Class representing an LTN predicate.

    An LTN predicate is grounded as a mathematical
    function (either pre-defined or learnable) that maps from some n-ary
    domain of individuals to a real number in [0,1] (fuzzy), which can be
    interpreted as a truth value.

    In LTNjax, the inputs of a predicate are automatically broadcasted
    before the computation of the predicate, if necessary. Moreover, the
    output is organized in a tensor where each axis is related to one
    free variable of the inputs.

    Attributes:
        model: A `nnx.Module` that evaluates this function.

    Note:
        - `model` will be called with a tensor that has a batch dimension and
          optionally feature dimensions. The batch dimension is the axis that
          results from flattening the axes of the free variables. See
          [_flatten_free_dims][ltnjax.core._flatten_free_dims].
    """

    def __init__(
        self, model: nnx.Module | None = None, func: Callable | None = None
    ) -> None:
        """Constructor.

        Initializes the LTN predicate in two different ways:
        1. if `model` is not None, it initializes the predicate with the given
        nnx.Module;
        2. if `model` is None, it uses the `func` as a function to define
        the LTN predicate (wrapped in a `LambdaModel`). Note that, in this
        case, the LTN predicate is not learnable. So, the lambda function has
        to be used only for simple predicates.

        Args:
            model: (default=None) A `nnx.Module` that evaluates this
                function.
            func: (default=None) A Python function, e.g. a lambda expression
                (`types.LambdaType` is the type of every plain Python
                function).

        Raises:
            ValueError: If both `model` and `func` are given, or if neither
                `model` nor `func` is given.
            TypeError: If `model` is given and not an `nnx.Module` or
                if `func` is given and not an `types.LambdaType` object.
        """
        if model is not None and func is not None:
            raise ValueError(
                "Both model and func parameters have been "
                "specified. Expected only one of the two "
                "parameters to be specified."
            )

        if model is None and func is None:
            raise ValueError(
                "Both model and func parameters have not been "
                "specified. Expected one of the two parameters "
                "to be specified."
            )

        if model is not None:
            if not isinstance(model, nnx.Module):
                raise TypeError(
                    "Predicate() : argument 'model' (position 1) "
                    "must be a nnx.Module, not " + str(type(model))
                )
            self.model = model
        else:  # func is not None
            if not isinstance(func, types.LambdaType):
                # Report the type of the offending `func` (not of `model`,
                # which is always None on this branch).
                raise TypeError(
                    "Predicate() : argument 'func' (position 2) "
                    "must be a function, not " + str(type(func))
                )
            self.model = LambdaModel(func)

    def __call__(self, *inputs: LTNObject, **kwargs: Any) -> LTNObject:
        """Evaluates the `model` of the given `inputs` and `kwargs`.

        Args:
            inputs: tuple of [LTNObject][ltnjax.core.LTNObject] to apply on
                `model`.
            kwargs: Further arguments to pass to `model`.

        Returns:
            `model(inputs)` as an LTN object whose axes are exactly the free
            variables of the (broadcasted) inputs.

        Raises:
            TypeError: If `inputs` are not of type
                [LTNObject][ltnjax.core.LTNObject].
        """
        # check input
        inputs_as_list = list(inputs)
        for x in inputs_as_list:
            if not isinstance(x, LTNObject):
                raise TypeError(
                    "The input to a LTN Predicate should be "
                    f"instances of {LTNObject}. Got an instance "
                    f"of {type(x)} instead."
                )

        # forward input
        inputs_as_list = _broadcast_exprs(inputs_as_list)
        # invariant: flat_inputs has shape: (flat) + feature_dims
        flat_inputs = _flatten_free_dims(inputs_as_list)
        # invariant: t_outputs has shape: (batch) + model_output
        t_outputs = self.model(*_as_arrays(flat_inputs), **kwargs)

        # recover shape
        free_vars = inputs_as_list[0].free_vars if inputs_as_list else []
        free_dims = (
            jnp.shape(inputs_as_list[0].value)[: len(free_vars)]
            if inputs_as_list
            else ()
        )

        # Case: Predicate
        # A predicate yields one truth value per assignment of the free
        # variables, so the model output must contain exactly prod(free_dims)
        # values. (Function keeps the trailing output axes instead.)
        t_outputs = jnp.reshape(t_outputs, free_dims)

        # ensure that values are float
        t_outputs = jnp.astype(t_outputs, jnp.float32)
        return LTNObject(t_outputs, free_vars)

547 

548 

class Function(nnx.Module):
    """Class representing LTN functions.

    A function that maps $n$ tensors of any rank to
    one single tensor of any rank.

    Attributes:
        model: A `nnx.Module` that evaluates this function.

    Note:
        - `model` will be called with a tensor that has a batch dimension and
          optionally feature dimensions. The batch dimension is the axis that
          results from flattening the axes of the free variables. See
          [_flatten_free_dims][ltnjax.core._flatten_free_dims].
    """

    def __init__(
        self,
        model: nnx.Module | None = None,
        func: Callable | None = None,
    ) -> None:
        """Constructor.

        Initializes the LTN function in two different ways:
        1. if `model` is not None, it initializes the function with the given
        nnx.Module;
        2. if `model` is None, it uses the `func` as a function to define
        the LTN function (wrapped in a `LambdaModel`). Note that, in this
        case, the LTN function is not learnable. So, the lambda function has
        to be used only for simple functions.

        Args:
            model: (default=None) A `nnx.Module` that evaluates this
                function.
            func: (default=None) A Python function, e.g. a lambda expression
                (`types.LambdaType` is the type of every plain Python
                function).

        Raises:
            ValueError: If both `model` and `func` are given, or if neither
                `model` nor `func` is given.
            TypeError: If `model` is given and not an `nnx.Module` or
                if `func` is given and not an `types.LambdaType` object.
        """
        if model is not None and func is not None:
            raise ValueError(
                "Both model and func parameters have been "
                "specified. Expected only one of the two "
                "parameters to be specified."
            )

        if model is None and func is None:
            raise ValueError(
                "Both model and func parameters have not been "
                "specified. Expected one of the two parameters "
                "to be specified."
            )

        if model is not None:
            if not isinstance(model, nnx.Module):
                raise TypeError(
                    "Function() : argument 'model' (position 1) "
                    "must be a nnx.Module, not " + str(type(model))
                )
            self.model = model
        else:  # func is not None
            if not isinstance(func, types.LambdaType):
                # Report the type of the offending `func` (not of `model`,
                # which is always None on this branch).
                raise TypeError(
                    "Function() : argument 'func' (position 2) "
                    "must be a function, not " + str(type(func))
                )
            self.model = LambdaModel(func)

    def __call__(self, *inputs: LTNObject, **kwargs: Any) -> LTNObject:
        """Evaluates the `model` of the given `inputs` and `kwargs`.

        Args:
            inputs: tuple of [LTNObject][ltnjax.core.LTNObject] to apply on
                `model`.
            kwargs: Further arguments to pass to `model`.

        Returns:
            `model(inputs)` as an LTN object: the leading axes are the free
            variables of the (broadcasted) inputs, followed by the model's
            per-row output axes.

        Raises:
            TypeError: If `inputs` are not of type
                [LTNObject][ltnjax.core.LTNObject].
        """
        # check input
        inputs_as_list = list(inputs)
        for x in inputs_as_list:
            if not isinstance(x, LTNObject):
                raise TypeError(
                    "The input to a LTN Function should be "
                    f"instances of {LTNObject}. Got an instance "
                    f"of {type(x)} instead."
                )

        # forward input
        inputs_as_list = _broadcast_exprs(inputs_as_list)
        # invariant: flat_inputs has shape: (flat) + feature_dims
        flat_inputs = _flatten_free_dims(inputs_as_list)
        # invariant: t_outputs has shape: (batch) + model_output
        t_outputs = self.model(*_as_arrays(flat_inputs), **kwargs)

        # recover shape
        free_vars = inputs_as_list[0].free_vars if inputs_as_list else []
        free_dims = (
            jnp.shape(inputs_as_list[0].value)[: len(free_vars)]
            if inputs_as_list
            else ()
        )

        # Case: Function
        # Unflatten the batch axis back into the free-variable axes and keep
        # the model's output axes as feature dimensions. (Predicate instead
        # reshapes to the free dimensions only.)
        t_outputs = jnp.reshape(
            t_outputs, (*free_dims, *jnp.shape(t_outputs)[1:])
        )

        # ensure that values are float
        t_outputs = jnp.astype(t_outputs, jnp.float32)
        return LTNObject(t_outputs, free_vars)

672 

673 

class LambdaModel(nnx.Module):
    """Simple `nnx.Module` that wraps a plain callable (e.g. a lambda).

    Used by `Predicate` and `Function` when they are built from `func`
    instead of a `model`; the wrapped callable has no trainable parameters.

    Attributes:
        lambda_layer: The wrapped callable (passed as `lambda_operator`).
    """

    def __init__(self, lambda_operator: Callable) -> None:
        """Constructor.

        Args:
            lambda_operator: Callable to wrap, e.g. a lambda expression.
        """
        super().__init__()
        self.lambda_layer = lambda_operator

    def __call__(self, *inputs: ArrayLike) -> Array:
        """Calls the wrapped callable with `inputs`.

        Args:
            inputs: Positional array arguments, forwarded unchanged.

        Returns:
            `lambda_layer(*inputs)`.
        """
        return self.lambda_layer(*inputs)

700 

701 

def diag(*variables: Variable) -> list[Variable]:
    """Diagonalizes `variables`.

    This diagonalizes a list of given [Variable][ltnjax.core.Variable] objects,
    i.e. this prepares the variables for the use of
    [Quantifier][ltnjax.core.Quantifier]. All variables get the same single
    free variable label `diag_<label1>_<label2>...`, so the i-th individual of
    each variable is paired with the i-th individual of the others.

    Args:
        variables: Tuple of [Variable][ltnjax.core.Variable] objects to
            diagonalize.

    Returns:
        List of the same [Variable][ltnjax.core.Variable] objects, modified in
        place so that they are in diagonalized form.

    Raises:
        TypeError: If `variables` are not of type
            [Variable][ltnjax.core.Variable].
        ValueError: If a variable in `variables` is already diagonalized
            (its free variable starts with `diag_`), or if the variables do
            not all have the same number of individuals.
    """
    variables_list = list(variables)
    # check if a list of LTN variables has been passed
    if not all(isinstance(x, Variable) for x in variables_list):
        raise TypeError(
            "Expected parameter 'vars' to be a tuple of Variable, "
            "but got " + str([type(v) for v in variables_list])
        )

    # check if variables are already diagged.
    for var in variables_list:
        if var.free_vars[0].startswith("diag_"):
            raise ValueError(
                "Trying to diag a variable that is already temporarily "
                f"diagged: {var.label}."
            )

    # Diagonal pairing needs the same number of individuals (first axis) in
    # every variable; otherwise later ops fail or silently broadcast.
    n_individuals = [jnp.shape(var.value)[0] for var in variables_list]
    if any(n != n_individuals[0] for n in n_individuals):
        raise ValueError(
            "Expected the same number of individuals for all the given "
            f"variables, but got {n_individuals}."
        )

    diag_label = "diag_" + "_".join(var.label for var in variables_list)
    for var in variables_list:
        var.free_vars = [diag_label]
    return variables_list

742 

743 

def undiag(*variables: Variable) -> list[Variable]:
    """Resets the `LTN broadcasting` for the given LTN variables.

    Every variable gets its own label back as its only free variable, which
    removes any `diagonal quantification` setting applied by `diag`.

    Args:
        variables: Tuple of LTN [Variable][ltnjax.core.Variable] objects for
            which the `diagonal quantification` setting has to be removed.

    Returns:
        List of the same LTN [Variable][ltnjax.core.Variable] objects given in
        input (modified in place), with the `diagonal quantification` setting
        removed.

    Raises:
        TypeError: If `variables` are not of type
            [Variable][ltnjax.core.Variable].
    """
    var_list = list(variables)
    for candidate in var_list:
        if not isinstance(candidate, Variable):
            raise TypeError(
                "Expected parameter 'vars' to be a tuple of Variable, "
                "but got " + str([type(v) for v in var_list])
            )

    for var in var_list:
        var.free_vars = [var.label]
    return var_list

774 

775 

776def _as_arrays(exprs: Sequence[LTNObject]) -> list[Array]: 

777 """Converts LTN objects to arrays. 

778 

779 This function takes a list of [LTNObject][ltnjax.core.LTNObject] objects 

780 and outputs a list of its values. 

781 

782 Args: 

783 exprs: List of [LTNObject][ltnjax.core.LTNObject] objects. 

784 

785 Returns: 

786 List of values of the given [LTNObject][ltnjax.core.LTNObject]. 

787 

788 Note: 

789 We only need to take sequences as they only must read-only. 

790 """ 

791 return [expr.value for expr in exprs] 

792 

793 

794def _broadcast_exprs( 

795 exprs: list[LTNObject], in_place: bool = False 

796) -> list[LTNObject]: 

797 """Broadcasts variables of `exprs`. 

798 

799 This collects the union of free_vars in `exprs` and broadcasts it to 

800 each LTN object in `exprs`, i.e. at the end each free variable in `exprs` 

801 is contained in each LTN object in `exprs` together with its own axis. 

802 

803 Args: 

804 exprs: List of [LTNObject][ltnjax.core.LTNObject] objects thats 

805 variables will be broadcasted. 

806 in_place: (default=False) Boolean that decides whether we perform the 

807 operation on a new copy or on the same LTN objects. 

808 

809 Returns: 

810 List of LTN objects but with broadcasted free variables. 

811 

812 Example: 

813 `exprs[0]` has shape `(5, 2)` and `free_vars = ['x']` and 

814 `exprs[1]` has shape `(10, 2)` and `free_vars = ['y']`. 

815 After broadcasting, both LTN objects will have shape `(5, 10, 2)` and 

816 `free_vars = ['x', 'y']`. 

817 >>> import ltn 

818 >>> import numpy as np 

819 >>> x = ltn.core.LTNObject(np.ones((5,2)), 'x') 

820 >>> y = ltn.core.LTNObject(np.ones((10,2)), 'y') 

821 >>> res = ltn.core._broadcast_exprs([x, y]) 

822 >>> print(res[0]) 

823 LTNObject( 

824 value=Array(shape=(5, 10, 2), dtype=dtype('float32')), 

825 free_vars=['x', 'y'] 

826 ) 

827 >>> print(res[1]) 

828 LTNObject( 

829 value=Array(shape=(5, 10, 2), dtype=dtype('float32')), 

830 free_vars=['x', 'y'] 

831 ) 

832 """ 

833 # measure dimensions for each free variable 

834 free_var_to_dim = {} 

835 for expr in exprs: 

836 for free_var in expr.free_vars: 

837 free_var_to_dim[free_var] = expr._get_dim_of_free_var(free_var) 

838 free_vars = list(free_var_to_dim.keys()) 

839 # broadcast 

840 if not in_place: 840 ↛ 842line 840 didn't jump to line 842 because the condition on line 840 was always true

841 exprs = [expr._copy() for expr in exprs] 

842 for expr in exprs: 

843 free_vars_in_arg = list(expr.free_vars) 

844 free_vars_not_in_arg = list( 

845 set(free_vars).difference(free_vars_in_arg) 

846 ) 

847 for new_free_var in free_vars_not_in_arg: 

848 new_idx = len(free_vars_in_arg) 

849 expr.value = jnp.expand_dims(expr.value, axis=new_idx) 

850 expr.value = jnp.repeat( 

851 expr.value, free_var_to_dim[new_free_var], axis=new_idx 

852 ) 

853 free_vars_in_arg.append(new_free_var) 

854 perm = [ 

855 free_vars_in_arg.index(free_var) for free_var in free_vars 

856 ] + list(range(len(free_vars_in_arg), jnp.ndim(expr.value))) 

857 expr.value = jnp.transpose(expr.value, axes=perm) 

858 expr.free_vars = free_vars 

859 return exprs 

860 

861 

class Connective:
    """Class representing an LTN connective.

    Wrapper for connectives that aggregates given
    [LTNObject][ltnjax.core.LTNObject] objects according to a given
    operation `connective_op` and also broadcasts variables, see
    [_broadcast_exprs][ltnjax.core._broadcast_exprs].

    Attributes:
        connective_op: Unary, binary, or aggregation operator from
            `ltn.fuzzy_ops`.
    """

    def __init__(
        self, connective_op: ltn.fuzzy_ops.ConnectiveOperator
    ) -> None:
        """Constructor.

        Args:
            connective_op: Unary, binary, or aggregation operator from
                `ltn.fuzzy_ops`.
        """
        self.connective_op = connective_op

    def __call__(self, *wffs: LTNObject, **kwargs: Any) -> LTNObject:
        """Applies the connective using the given `connective_op`.

        Args:
            wffs: Tuple of LTN objects.
            kwargs: Further arguments to pass to `connective_op`.

        Returns:
            The resulting [LTNObject][ltnjax.core.LTNObject] object that
            combines the given `wffs` into one joint LTN object.

        Raises:
            TypeError: If `wffs` are not of type
                [LTNObject][ltnjax.core.LTNObject], or if `connective_op` is
                not a unary, binary, or aggregation operator.
            ValueError: If number of `wffs` does not fit to `connective_op`.
        """
        operands = list(wffs)
        for x in operands:
            if not isinstance(x, LTNObject):
                raise TypeError(
                    "The operands of a LTN connective should be "
                    f"instances of {LTNObject}. Got an instance "
                    f"of {type(x)} instead."
                )

        operands = _broadcast_exprs(operands)
        op = self.connective_op
        if isinstance(op, ltn.fuzzy_ops.UnaryConnectiveOperator):
            if len(operands) != 1:
                raise ValueError(
                    "wffs must have length 1 since connective_op "
                    "is an UnaryConnectiveOperator."
                )
            t_result = op(*_as_arrays(operands), **kwargs)
        elif isinstance(op, ltn.fuzzy_ops.BinaryConnectiveOperator):
            if len(operands) != 2:
                raise ValueError(
                    "wffs must have length 2 since connective_op "
                    "is an BinaryConnectiveOperator."
                )
            t_result = op(*_as_arrays(operands), **kwargs)
        elif isinstance(op, ltn.fuzzy_ops.AggregationOperator):
            if len(operands) < 1:
                raise ValueError("wffs must have length at least 1.")
            # Stack the broadcasted operands on a new leading axis and
            # aggregate over it.
            t_result = op(jnp.stack(_as_arrays(operands)), axis=0, **kwargs)
        else:
            # Without this branch `t_result` would be unbound below and the
            # call would fail with an UnboundLocalError.
            raise TypeError(
                "connective_op must be an UnaryConnectiveOperator, "
                "BinaryConnectiveOperator or AggregationOperator, not "
                f"{type(op)}."
            )
        return LTNObject(t_result, operands[0].free_vars)

936 

937 

938class Quantifier: 

939 r"""Class representing an LTN quantifier. 

940 

941 Wrapper for Quantifiers. This evaluates a given LTN object `wff` for 

942 all variable combinations for that the condition `mask` is true. Then, the 

943 results will be aggregated with the aggregation operator `aggreg_op`. 

944 

945 Attributes: 

946 aggreg_op: Aggregation operator. 

947 quantifier: (str = "f" | "e"`) Decides whether this is a 

948 "forall"- or "exists"-quantifier. This has no effect on the 

949 aggregator but is important for cases, where we aggregate over 

950 $\emptyset$. This may happen if the variables are empty 

951 or the mask masks each variable-combination. In these cases, 

952 "forall" expressions are true while "exists"-quantifiers are 

953 false. 

954 If the Quantifier is used with non-truth values, the quantifier 

955 can be used as the <b>neural elements</b> like for `sum` or 

956 `prod`. 

957 

958 Raises: 

959 TypeError: If `aggreg_op` is not of type 

960 [ConnectiveOperator][ltnjax.fuzzy_ops.ConnectiveOperator]. 

961 ValueError: If `quantifier` is not one of the strings `forall` or 

962 `exists`. 

963 

964 Note: 

965 It is possible that the variable-combinations are empty or that the 

966 condition `mask` will mask every variable-combination. In both cases, 

967 a "forall"-statement will always be true while an "exists"-statement 

968 will always be false. 

969 """ 

970 

971 def __init__( 

972 self, aggreg_op: ltn.fuzzy_ops.ConnectiveOperator, quantifier: str 

973 ) -> None: 

974 r"""Constructor. 

975 

976 Args: 

977 aggreg_op: Aggregation operator. 

978 quantifier: (str = "f" | "e"`) Decides whether this is a 

979 "forall"- or "exists"-quantifier. This has no effect on the 

980 aggregator but is important for cases, where we aggregate over 

981 $\emptyset$. This may happen if the variables are empty 

982 or the mask masks each variable-combination. In these cases, 

983 "forall" expressions are true while "exists"-quantifiers are 

984 false. 

985 If the Quantifier is used with non-truth values, the quantifier 

986 can be used as the <b>neural elements</b> like for `sum` or 

987 `prod`. 

988 """ 

989 self.aggreg_op = aggreg_op 

990 if not isinstance(aggreg_op, ltn.fuzzy_ops.ConnectiveOperator): 

991 raise TypeError( 

992 "The aggregation operator for the quantifier " 

993 "should be an instance of " 

994 f"{ltn.fuzzy_ops.ConnectiveOperator}. Got an " 

995 f"instance of {type(aggreg_op)} instead." 

996 ) 

997 if quantifier not in ["f", "e"]: 

998 raise ValueError( 

999 '`quantifier` for the quantifier should be "f" or "e".' 

1000 ) 

1001 self.quantifier = quantifier 

1002 

1003 def __call__( 

1004 self, 

1005 variables: Variable | list[Variable], 

1006 wff: LTNObject, 

1007 mask: LTNObject | None = None, 

1008 **kwargs: Any, 

1009 ) -> LTNObject: 

1010 """Applies the quantification and outputs the resulting LTN object. 

1011 

1012 As a side-effect, this removes the `diagonal quantification` from the 

1013 given `variables`. 

1014 Refer to [undiag][ltnjax.core.undiag]. 

1015 

1016 Args: 

1017 variables: Variable or list of variables. 

1018 wff: LTN object. 

1019 mask: (default=None) Condition operation. 

1020 kwargs: Further arguments to pass to `connective_op`. 

1021 

1022 Returns: 

1023 The resulting LTN object. 

1024 

1025 Raises: 

1026 TypeError: If the values of `variables` are not instances of 

1027 [Variable][ltnjax.core.Variable] or if `wff` is not of type 

1028 [LTNObject][ltnjax.core.LTNObject]. 

1029 """ 

1030 # check inputs 

1031 variables = ( 

1032 [variables] if not isinstance(variables, list) else variables 

1033 ) 

1034 for x in variables: 

1035 if not isinstance(x, Variable): 

1036 raise TypeError( 

1037 "The quantified variables should be " 

1038 f"instances of {Variable}. Got an instance of " 

1039 "{type(x)} instead." 

1040 ) 

1041 if not isinstance(wff, LTNObject): 1041 ↛ 1042line 1041 didn't jump to line 1042 because the condition on line 1041 was never true

1042 raise TypeError( 

1043 "The quantified LTN object should be an instance " 

1044 f"of {LTNObject}. Got an instance of {type(x)} " 

1045 "instead ." 

1046 ) 

1047 if mask is not None and not isinstance(mask, LTNObject): 

1048 raise TypeError( 

1049 "The mask argument should be an instance of " 

1050 f"{LTNObject}. Got an instance of {type(mask)} " 

1051 "instead." 

1052 ) 

1053 

1054 # Note: For the edge-case that variables are empty. 

1055 # Since the empty variables are already passed into predicates that 

1056 # return size-0 arrays, we have to check wff. 

1057 # This will be done before self.aggreg_op is applied. 

1058 

1059 aggreg_vars = {var.free_vars[0] for var in variables} 

1060 

1061 if mask is not None: 

1062 # This block is for broadcasting variables in wff and the 

1063 # variables in mask. 

1064 # Important to put aggreg dims last, to keep other dims in the 

1065 # ragged result. 

1066 mask = Quantifier._transpose_free_vars( 

1067 mask, 

1068 new_var_order=[ 

1069 var for var in mask.free_vars if var not in aggreg_vars 

1070 ] 

1071 + [var for var in mask.free_vars if var in aggreg_vars], 

1072 ) 

1073 wff = Quantifier._broadcast_wff_and_mask(wff, mask) 

1074 

1075 mask.value = jnp.astype(mask.value, jnp.bool) 

1076 

1077 # Ignore vars in variables that do not occur in wff.free_vars. 

1078 aggreg_axes = [ 

1079 wff.free_vars.index(var) 

1080 for var in aggreg_vars 

1081 if var in wff.free_vars 

1082 ] 

1083 if wff.value.size == 0: # Check edge case, when wff is empty: 

1084 t_result = ( 

1085 jnp.prod(wff.value, axis=aggreg_axes, **kwargs) 

1086 if self.quantifier == "f" 

1087 else jnp.sum(wff.value, axis=aggreg_axes, **kwargs) 

1088 ) 

1089 else: 

1090 t_result = self.aggreg_op( 

1091 wff.value, axis=aggreg_axes, mask=mask.value, **kwargs 

1092 ) 

1093 

1094 aggreg_axes_in_mask = [ 

1095 mask.free_vars.index(var) 

1096 for var in aggreg_vars 

1097 if var in mask.free_vars 

1098 ] 

1099 # empty_vars are the variable-combinations that are completely 

1100 # masked and we need to apply replacement value `rep_value`. 

1101 non_empty_vars = ( 

1102 jnp.sum( 

1103 jnp.astype(mask.value, jnp.int32), axis=aggreg_axes_in_mask 

1104 ) 

1105 != 0 

1106 ) 

1107 rep_value = 1.0 if self.quantifier == "f" else 0 

1108 

1109 t_result = jnp.where(non_empty_vars, t_result, rep_value) 

1110 else: 

1111 # ignore vars in variables that do not occur in wff.free_vars 

1112 aggreg_axes = [ 

1113 wff.free_vars.index(var) 

1114 for var in aggreg_vars 

1115 if var in wff.free_vars 

1116 ] 

1117 if wff.value.size == 0: # Check edge case, when wff is empty: 

1118 t_result = ( 

1119 jnp.prod(wff.value, axis=aggreg_axes, **kwargs) 

1120 if self.quantifier == "f" 

1121 else jnp.sum(wff.value, axis=aggreg_axes, **kwargs) 

1122 ) 

1123 else: 

1124 t_result = self.aggreg_op( 

1125 wff.value, axis=aggreg_axes, **kwargs 

1126 ) 

1127 free_vars_remaining = [ 

1128 var for var in wff.free_vars if var not in aggreg_vars 

1129 ] 

1130 result = LTNObject(t_result, free_vars_remaining) 

1131 undiag(*variables) 

1132 return result 

1133 

1134 @staticmethod 

1135 def _broadcast_wff_and_mask( 

1136 wff: LTNObject, mask: LTNObject, in_place: bool = False 

1137 ) -> LTNObject: 

1138 """Broadcasts the free variables from `mask` to `wff`. 

1139 

1140 The variables of `mask` are put in the first axes. 

1141 

1142 Args: 

1143 wff: LTN object. 

1144 mask: LTN object. 

1145 in_place: (default=False) Boolean that decides whether we perform 

1146 the operation on `wff` or a new copy. 

1147 

1148 Returns: 

1149 The LTN object `wff` with the vars from `mask` added. 

1150 """ 

1151 if not in_place: 1151 ↛ 1155line 1151 didn't jump to line 1155 because the condition on line 1151 was always true

1152 wff = wff._copy() 

1153 # 1. Broadcast wff with vars that are in the mask but not yet in the 

1154 # LTN object. 

1155 mask_vars_not_in_wff = [ 

1156 var for var in mask.free_vars if var not in wff.free_vars 

1157 ] 

1158 for var in mask_vars_not_in_wff: 

1159 new_idx = len(wff.free_vars) 

1160 wff.value = jnp.expand_dims(wff.value, axis=new_idx) 

1161 wff.value = jnp.repeat( 

1162 wff.value, mask._get_dim_of_free_var(var), axis=new_idx 

1163 ) 

1164 wff.free_vars.append(var) 

1165 # 2. Transpose wff so that the masked vars on the first axes. 

1166 vars_not_in_mask = [ 

1167 var for var in wff.free_vars if var not in mask.free_vars 

1168 ] 

1169 wff = Quantifier._transpose_free_vars( 

1170 wff, new_var_order=mask.free_vars + vars_not_in_mask 

1171 ) 

1172 return wff 

1173 

1174 @staticmethod 

1175 def _transpose_free_vars( 

1176 expr: LTNObject, new_var_order: list[VarLabel], in_place: bool = False 

1177 ) -> LTNObject: 

1178 """Transposes free variables. 

1179 

1180 This changes the order of variables in `expr.free_vars` and the 

1181 axes of `expr.value` will be transposed accordingly. 

1182 

1183 Args: 

1184 expr: The LTN object whose variables will be transposed. 

1185 new_var_order: List of variables that defines the new order. 

1186 in_place: (default=False) Boolean that decides whether we perform 

1187 the operation on a new copy or on the same LTN objects. 

1188 

1189 Returns: 

1190 The transposed LTN object. 

1191 """ 

1192 permutation = [expr.free_vars.index(var) for var in new_var_order] 

1193 if not in_place: 1193 ↛ 1195line 1193 didn't jump to line 1195 because the condition on line 1193 was always true

1194 expr = expr._copy() 

1195 expr.value = jnp.transpose(expr.value, permutation) 

1196 expr.free_vars = new_var_order 

1197 return expr