Customization
While the default functionality of compyre.is_equal and compyre.assert_equal should be good enough for common
use cases, we acknowledge the fact that they cannot cover everything. We designed compyre with this in mind and also
provide a fully customizable API that can be tailored to your needs.
In this tutorial you learn about custom unpacking and equality check functions, how to implement them, and how to integrate
them into compyre.
Unpacking functions
Unpacking functions are set through the unpack_fns parameter and are iterated in order until the first non-None
result:
- If the result is a list of compyre.api.Pairs, they are added to the beginning of the processing queue, resulting in a depth-first traversal of the inputs.
- If the result is an Exception, it is stored and further processing for the pair is skipped.
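The dispatch rules above can be illustrated with a small, self-contained sketch. It is not compyre's actual implementation: the Pair class, unpack_sequence, and traverse are hypothetical stand-ins, and the handling of pairs that no unpacking function accepts (passing them on to the equality checks) is an assumption.

```python
from collections import deque
from dataclasses import dataclass
from typing import Any


@dataclass
class Pair:
    # hypothetical stand-in for compyre.api.Pair
    index: tuple
    actual: Any
    expected: Any


def unpack_sequence(p):
    # stand-in unpacking function: only handles pairs of lists
    if not (isinstance(p.actual, list) and isinstance(p.expected, list)):
        return None
    if len(p.actual) != len(p.expected):
        return ValueError(f"length mismatch at index {p.index}")
    return [
        Pair(index=(*p.index, i), actual=a, expected=e)
        for i, (a, e) in enumerate(zip(p.actual, p.expected))
    ]


def traverse(actual, expected, unpack_fns):
    queue = deque([Pair(index=(), actual=actual, expected=expected)])
    leaves, errors = [], []
    while queue:
        pair = queue.popleft()
        for fn in unpack_fns:
            result = fn(pair)
            if result is None:
                continue  # this function does not handle the pair, try the next one
            if isinstance(result, Exception):
                errors.append(result)  # store the error and skip the pair
            else:
                # prepend the new pairs in their original order: depth-first traversal
                queue.extendleft(reversed(result))
            break
        else:
            # assumed: pairs that nothing unpacks are left for the equality checks
            leaves.append(pair)
    return leaves, errors


leaves, errors = traverse([[1, 2], [3]], [[1, 2], [4]], [unpack_sequence])
print([p.index for p in leaves])  # [(0, 0), (0, 1), (1, 0)]
```

The children of the first inner list are visited before the second inner list, which is the depth-first order that prepending to the queue produces.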
Suppose you want to compare two numpy.ndarrays.
Simplified, compyre.assert_equal for this use case is equivalent to
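import numpy as np

import compyre.api
import compyre.builtin

# input values as they appear in the output below
expected = np.array([1, 2, 3])
actual = np.array([1, 3, -3])

try:
    compyre.api.assert_equal(
        actual,
        expected,
        unpack_fns=[],
        equal_fns=[compyre.builtin.equal_fns.numpy_ndarray],
    )
except AssertionError as e:
    print(e)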
comparison resulted in 1 error(s):
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0
Mismatched elements: 2 / 3 (66.7%)
Max absolute difference among violations: 6
Max relative difference among violations: 2.
ACTUAL: array([ 1, 3, -3])
DESIRED: array([1, 2, 3])
Suppose that you would rather compare the numbers elementwise, as you would for a list. This can be achieved with a simple unpacking function that uses numpy.ndarray.tolist and compyre.builtin.unpack_fns.collections_sequence to create the pairs with the correct index information.
def unpack_array(p: compyre.api.Pair) -> compyre.api.UnpackFnResult:
    # both inputs have to be arrays for this function to try to unpack them
    if not (isinstance(p.actual, np.ndarray) and isinstance(p.expected, np.ndarray)):
        return None

    # convert to lists and let the builtin sequence unpacker create the indexed pairs
    return compyre.builtin.unpack_fns.collections_sequence(
        compyre.api.Pair(index=p.index, actual=p.actual.tolist(), expected=p.expected.tolist())
    )
Using compyre.builtin.equal_fns.builtins_number as the equality check function, the elementwise equality check can be achieved with
try:
    compyre.api.assert_equal(
        actual,
        expected,
        unpack_fns=[unpack_array],
        equal_fns=[compyre.builtin.equal_fns.builtins_number],
    )
except AssertionError as e:
    print(e)
comparison resulted in 2 error(s):
1
AssertionError: Numbers 3 and 2 are not close!
Absolute difference: 1
Relative difference: 0.3333333333333333 (up to 1e-09 allowed)
2
AssertionError: Numbers -3 and 3 are not close!
Absolute difference: 6
Relative difference: 2.0 (up to 1e-09 allowed)
As written, only one-dimensional arrays can be unpacked. For multi-dimensional arrays we get back a nested list of
elements that unpack_array cannot unpack further and that compyre.builtin.equal_fns.builtins_number cannot compare.
expected = np.array([[1, 2], [3, 4]])
actual = np.array([[1, 2], [4, 4]])

try:
    compyre.api.assert_equal(
        actual,
        expected,
        unpack_fns=[unpack_array],
        equal_fns=[compyre.builtin.equal_fns.builtins_number],
    )
except compyre.api.CompyreError as e:
    print(e)
0
CompyreError: unable to compare [1, 2] of type <class 'list'> and [1, 2] of type <class 'list'>
1
CompyreError: unable to compare [4, 4] of type <class 'list'> and [3, 4] of type <class 'list'>
To overcome this, we could refactor unpack_array to return a flattened list of all elements. However, this is rather
cumbersome. An easier solution is to include compyre.builtin.unpack_fns.collections_sequence in the unpacking
functions.
try:
    compyre.api.assert_equal(
        actual,
        expected,
        unpack_fns=[
            unpack_array,
            compyre.builtin.unpack_fns.collections_sequence,
        ],
        equal_fns=[compyre.builtin.equal_fns.builtins_number],
    )
except AssertionError as e:
    print(e)
comparison resulted in 1 error(s):
1.0
AssertionError: Numbers 4 and 3 are not close!
Absolute difference: 1
Relative difference: 0.25 (up to 1e-09 allowed)
Equipped with this knowledge and under the assumption that
compyre.builtin.unpack_fns.collections_sequence or something equivalent is always present in the unpacking
functions, we can even simplify unpack_array further.
def unpack_array_simple(p: compyre.api.Pair) -> compyre.api.UnpackFnResult:
    # both inputs have to be arrays for this function to try to unpack them
    if not (isinstance(p.actual, np.ndarray) and isinstance(p.expected, np.ndarray)):
        return None

    # return a single pair of (nested) lists and leave the rest to the sequence unpacker
    return [
        compyre.api.Pair(index=p.index, actual=p.actual.tolist(), expected=p.expected.tolist())
    ]
try:
    compyre.api.assert_equal(
        actual,
        expected,
        unpack_fns=[
            unpack_array_simple,
            compyre.builtin.unpack_fns.collections_sequence,
        ],
        equal_fns=[compyre.builtin.equal_fns.builtins_number],
    )
except AssertionError as e:
    print(e)
comparison resulted in 1 error(s):
1.0
AssertionError: Numbers 4 and 3 are not close!
Absolute difference: 1
Relative difference: 0.25 (up to 1e-09 allowed)
Equality check functions
Equality check functions are set through the equal_fns parameter and are iterated in order until the first non-None
result:
- If the result is True, the input pair is considered equal.
- If the result is False or any Exception, the input pair is considered not equal. In the former case an AssertionError with a default error message is used.
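As a rough illustration of these rules, the following self-contained sketch dispatches a pair over a list of equality check functions. It is not compyre's implementation: check_pair, ints, and strings are hypothetical names, SimpleNamespace stands in for compyre.api.Pair, a RuntimeError stands in for the error reported when no function can compare a pair, and the default message is modelled on the output shown in this tutorial.

```python
from types import SimpleNamespace


def check_pair(pair, equal_fns):
    """Return None if the pair is equal, else an exception describing the failure."""
    for fn in equal_fns:
        result = fn(pair)
        if result is None:
            continue  # not responsible for this pair, try the next function
        if result is True:
            return None
        if result is False:
            # False carries no message, so a default one is created
            return AssertionError(f"{pair.actual} is not equal to {pair.expected}")
        return result  # an Exception returned by the function is used as is
    # hypothetical stand-in for the error reported when nothing can compare the pair
    return RuntimeError(f"unable to compare {pair.actual} and {pair.expected}")


def ints(p):
    if not (isinstance(p.actual, int) and isinstance(p.expected, int)):
        return None
    return p.actual == p.expected


def strings(p):
    if not (isinstance(p.actual, str) and isinstance(p.expected, str)):
        return None
    if p.actual == p.expected:
        return True
    return ValueError(f"strings differ: {p.actual!r} vs {p.expected!r}")


fns = [ints, strings]
for actual, expected in [(1, 1), (1, 2), ("a", "b"), (1.5, 1.5)]:
    print(check_pair(SimpleNamespace(actual=actual, expected=expected), fns))
```

The four pairs exercise the four outcomes: equal, not equal via False, not equal via a returned Exception, and no function responsible.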
Suppose you want to compare dictionaries only by their keys and lists only by their length, rather than unpacking them and comparing their values elementwise.
import dataclasses


@dataclasses.dataclass
class Container:
    dct: dict
    lst: list


expected = Container(dct={"foo": "foo", "bar": "barbar"}, lst=[0, 1, 2])
actual = Container(dct={"foo": "foo", "baz": "barbar"}, lst=[0, 1])
This can be achieved with two simple equality check functions
import compyre.api


def dict_keys(p: compyre.api.Pair) -> compyre.api.EqualFnResult:
    if not (isinstance(p.actual, dict) and isinstance(p.expected, dict)):
        return None

    if p.actual.keys() == p.expected.keys():
        return True
    else:
        return AssertionError(f"{p.actual.keys()=} != {p.expected.keys()=}")


def list_len(p: compyre.api.Pair) -> compyre.api.EqualFnResult:
    if not (isinstance(p.actual, list) and isinstance(p.expected, list)):
        return None

    return len(p.actual) == len(p.expected)
Using compyre.builtin.unpack_fns.dataclasses_dataclass as the unpacking function, the key and length check can be achieved with
try:
    compyre.api.assert_equal(
        actual,
        expected,
        unpack_fns=[compyre.builtin.unpack_fns.dataclasses_dataclass],
        equal_fns=[
            dict_keys,
            list_len,
        ],
    )
except AssertionError as e:
    print(e)
comparison resulted in 2 error(s):
dct
AssertionError: p.actual.keys()=dict_keys(['foo', 'baz']) != p.expected.keys()=dict_keys(['foo', 'bar'])
lst
AssertionError: [0, 1] is not equal to [0, 1, 2]
Note that returning False from list_len, instead of an Exception with a custom error message, resulted in the
inequality that we wanted, but the default error message is not great. One cannot derive from it that the values are
actually irrelevant.
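One way to address this is to return an AssertionError that mentions only the lengths, mirroring what dict_keys does for keys. The sketch below uses the hypothetical name list_len_with_message and omits the compyre type annotations so that it runs on its own; SimpleNamespace stands in for compyre.api.Pair.

```python
from types import SimpleNamespace


def list_len_with_message(p):
    # In a real project, annotate as in the tutorial:
    # (p: compyre.api.Pair) -> compyre.api.EqualFnResult
    if not (isinstance(p.actual, list) and isinstance(p.expected, list)):
        return None

    if len(p.actual) == len(p.expected):
        return True
    # the message only mentions the lengths, since the values are irrelevant here
    return AssertionError(f"{len(p.actual)=} != {len(p.expected)=}")


result = list_len_with_message(SimpleNamespace(actual=[0, 1], expected=[0, 1, 2]))
print(result)  # len(p.actual)=2 != len(p.expected)=3
```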
Parameters
All functions from compyre.builtin.unpack_fns and compyre.builtin.equal_fns as well as all the unpacking and equality check functions in this tutorial so far do not require any parameters. However, that is not a limitation.
Suppose you want to implement an equality function that checks strings either for equality or just their length. You could implement it like
import compyre.api


def string_equal(p: compyre.api.Pair, /, *, len_only: bool) -> compyre.api.EqualFnResult:
    if not (isinstance(p.actual, str) and isinstance(p.expected, str)):
        return None

    if len_only:
        return len(p.actual) == len(p.expected)
    else:
        return p.actual == p.expected
Calling compyre.api.assert_equal as before now results in a TypeError, because the len_only parameter is not
passed.
expected = "foo"
actual = "bar"

try:
    compyre.api.assert_equal(
        actual,
        expected,
        unpack_fns=[],
        equal_fns=[string_equal],
    )
except TypeError as e:
    print(e)
We can just pass it as a keyword argument to compyre.api.assert_equal
compyre.api.assert_equal(
    actual,
    expected,
    unpack_fns=[],
    equal_fns=[string_equal],
    len_only=True,
)
To avoid subtle errors, a TypeError is also raised if a keyword argument is passed that is not used by any unpacking or equality check function.
try:
    compyre.api.assert_equal(
        actual,
        expected,
        unpack_fns=[],
        equal_fns=[],
        len_only=True,
    )
except TypeError as e:
    print(e)
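To give an intuition for how such a check for unused keyword arguments could work, here is a minimal sketch based on inspect.signature. It is only an illustration under assumptions: the helper names and the error message are hypothetical, and compyre's actual mechanism may differ.

```python
import inspect


def keyword_only_names(fn):
    # names of the keyword-only parameters that a function declares
    return {
        name
        for name, param in inspect.signature(fn).parameters.items()
        if param.kind is inspect.Parameter.KEYWORD_ONLY
    }


def check_kwargs(fns, kwargs):
    # collect every keyword-only parameter accepted by at least one function
    accepted = set().union(*(keyword_only_names(fn) for fn in fns))
    unused = sorted(set(kwargs) - accepted)
    if unused:
        raise TypeError(f"unused keyword argument(s): {', '.join(unused)}")


def string_equal(p, /, *, len_only):
    if not (isinstance(p.actual, str) and isinstance(p.expected, str)):
        return None
    return len(p.actual) == len(p.expected) if len_only else p.actual == p.expected


check_kwargs([string_equal], {"len_only": True})  # accepted by string_equal

try:
    check_kwargs([], {"len_only": True})  # no function uses len_only
except TypeError as e:
    print(e)
```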
Low-level API
If the customization options detailed so far in this tutorial are still not sufficient for your use case, you can base your logic on the low-level compyre.api.compare. Everything discussed above is still valid, but instead of a boolean check like compyre.api.is_equal or an equality assertion like compyre.api.assert_equal, compyre.api.compare returns the list of Exceptions returned by the unpacking or equality check functions. Thus, you have the option to post-filter the results, produce a custom combined error message, and so on.
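Post-filtering and building a combined message could look roughly like the sketch below. It works on a plain list of exception objects as a stand-in for the result of compyre.api.compare; the exact element type of that result is not covered in this tutorial, so treat this shape as an assumption. The summarize helper is hypothetical.

```python
def summarize(errors, ignore=()):
    """Drop errors of the ignored types and combine the rest into one message."""
    kept = [e for e in errors if not isinstance(e, ignore)]
    if not kept:
        return None
    lines = [f"{len(kept)} relevant error(s):"]
    lines += [f"- {type(e).__name__}: {e}" for e in kept]
    return "\n".join(lines)


# stand-in for the result of compyre.api.compare(...)
errors = [
    AssertionError("Numbers 4 and 3 are not close!"),
    ValueError("length mismatch"),
]
print(summarize(errors, ignore=(ValueError,)))
```

Here the ValueError is filtered out and only the AssertionError ends up in the combined message.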