from ..interface import Contract, ContractNotRespected, RValue, eval_in_context
from ..syntax import W, Keyword, add_contract, add_keyword
from .types_misc import CheckType
from abc import abstractmethod
import numpy as np
[docs]class ArrayElementsTest(Contract):
[docs] @abstractmethod
def test_elements(self, context, value):
""" Returns either a bool or an array of bool. """
[docs] def check_contract(self, context, value, silent):
result = self.test_elements(context, value)
if np.all(result):
return
result = np.array(result) # for simple bool
resultf = result.flatten()
valuef = value.flatten()
some, = np.nonzero(np.logical_not(resultf))
num = value.size
num_fail = len(some)
perc = 100.0 * num_fail / num
error = ("In this array, %d/%d (%f%%) of elements do not respect "
"the condition %s." % (num_fail, num, perc, self))
some_failures = valuef[some]
MAX_N = 4
if len(some_failures) > MAX_N:
some_failures = some_failures[:MAX_N]
failures = list(some_failures)
N = len(failures)
error += '\nThese are the first %d: %s.' % (N, failures)
raise ContractNotRespected(self, error, value, context)
[docs]class ArrayLogical(ArrayElementsTest):
def __init__(self, glyph, precedence):
self.glyph = glyph
self.precedence = precedence
def __str__(self):
def convert(x):
if isinstance(x, ArrayLogical) and x.precedence < self.precedence:
return '(%s)' % x
else:
return '%s' % x
s = self.glyph.join(convert(x) for x in self.clauses)
return s
[docs]class ArrayOR(ArrayLogical):
def __init__(self, clauses, where=None):
assert isinstance(clauses, list)
assert len(clauses) >= 2
for c in clauses:
assert isinstance(c, ArrayElementsTest)
Contract.__init__(self, where)
ArrayLogical.__init__(self, '|', 1)
self.clauses = clauses
[docs] def test_elements(self, context, value):
assert isinstance(value, np.ndarray)
result = False
for c in self.clauses:
result_c = c.test_elements(context, value)
result = np.logical_or(result_c, result)
return result
def __repr__(self):
s = 'ArrayOR(%r)' % self.clauses
return s
[docs] @staticmethod
def parse_action(string, location, tokens):
l = list(tokens[0])
clauses = [l.pop(0)]
while l:
glyph = l.pop(0) # @UnusedVariable
assert glyph == '|'
operand = l.pop(0)
clauses.append(operand)
where = W(string, location)
return ArrayOR(clauses, where=where)
[docs]class ArrayORCustomString(ArrayOR):
def __init__(self, custom_string, **other):
self.custom_string = custom_string
ArrayOR.__init__(self, **other)
def __str__(self):
return self.custom_string
[docs]class ArrayAnd(ArrayLogical):
def __init__(self, clauses, where=None):
assert isinstance(clauses, list)
assert len(clauses) >= 2, clauses
assert isinstance(clauses, list)
assert len(clauses) >= 2
for c in clauses:
assert isinstance(c, ArrayElementsTest)
Contract.__init__(self, where)
ArrayLogical.__init__(self, ',', 2)
self.clauses = clauses
[docs] def test_elements(self, context, value):
assert isinstance(value, np.ndarray)
result = True
for c in self.clauses:
result_c = c.test_elements(context, value)
result = np.logical_and(result_c, result)
return result
def __repr__(self):
s = 'ArrayAnd(%r)' % self.clauses
return s
[docs] @staticmethod
def parse_action(string, location, tokens):
l = list(tokens[0])
clauses = [l.pop(0)]
while l:
glyph = l.pop(0) # @UnusedVariable
assert glyph == ','
operand = l.pop(0)
clauses.append(operand)
where = W(string, location)
return ArrayAnd(clauses, where=where)
[docs]class ArrayConstraint(ArrayElementsTest):
""" Comparisons for numpy array elements. They check that
the condition is respected for all the entries in the array. """
constraints = {
'=': lambda x, rvalue: x == rvalue,
'==': lambda x, rvalue: x == rvalue,
'!=': lambda x, rvalue: x != rvalue,
'>': lambda x, rvalue: x > rvalue,
'>=': lambda x, rvalue: x >= rvalue,
'<': lambda x, rvalue: x < rvalue,
'<=': lambda x, rvalue: x <= rvalue,
}
def __init__(self, glyph, rvalue, where=None):
assert glyph in ArrayConstraint.constraints
assert isinstance(rvalue, RValue)
Contract.__init__(self, where)
self.glyph = glyph
self.rvalue = rvalue
[docs] def test_elements(self, context, value):
""" Returns either a bool or an array of bool. """
assert isinstance(value, np.ndarray)
bound = eval_in_context(context=context, value=self.rvalue,
contract=self)
operation = ArrayConstraint.constraints[self.glyph]
result = operation(value, bound)
return result
def __str__(self):
return '%s%s' % (self.glyph, self.rvalue)
def __repr__(self):
return 'ArrayConstraint(%r,%r)' % (self.glyph, self.rvalue)
[docs] @staticmethod
def parse_action(s, loc, tokens):
where = W(s, loc)
glyph = "".join(tokens['glyph'])
rvalue = tokens['rvalue']
return ArrayConstraint(glyph, rvalue, where)
[docs]class DType(ArrayElementsTest):
""" Checks that the value is an array with the given dtype. """
def __init__(self, dtype, dtype_string=None, where=None):
assert isinstance(dtype, np.dtype)
Contract.__init__(self, where)
self.dtype = dtype
if dtype_string is None:
dtype_string = "%s" % dtype
self.dtype_string = dtype_string
[docs] def test_elements(self, context, value): # @UnusedVariable
assert isinstance(value, np.ndarray) # Guaranteed by construction
return (value.dtype == self.dtype)
def __str__(self):
return self.dtype_string
def __repr__(self):
if "%s" % self.dtype == self.dtype_string:
return 'DType(%r)' % self.dtype
else:
return 'DType(%r,%r)' % (self.dtype, self.dtype_string)
[docs] @staticmethod
def parse_action(dtype=None):
assert dtype is None or isinstance(dtype, np.dtype)
def parse(s, loc, tokens):
where = W(s, loc)
dtype_string = tokens[0]
if dtype is None:
use_dtype = np.dtype(dtype_string)
else:
use_dtype = dtype
return DType(use_dtype, dtype_string, where)
return parse
np_types = {
'np_int': np.int, # Platform integer (normally either int32 or int64)
'np_int8': np.int8, # Byte (-128 to 127)
'np_int16': np.int16, # Integer (-32768 to 32767)
'np_int32': np.int32, # Integer (-2147483648 to 2147483647)
'np_int64': np.int64, # Integer (9223372036854775808 to 9223372036854775807)
'np_uint8': np.uint8, # Unsigned integer (0 to 255)
'np_uint16': np.uint16, # Unsigned integer (0 to 65535)
'np_uint32': np.uint32, # Unsigned integer (0 to 4294967295)
'np_uint64': np.uint64, # Unsigned integer (0 to 18446744073709551615)
'np_float': np.float, # Shorthand for float64.
'np_float16': np.float16, # Half precision float: sign bit, 5 bits exponent, 10 bits mantissa
'np_float32': np.float32, # Single precision float: sign bit, 8 bits exponent, 23 bits mantissa
'np_float64': np.float64, # Double precision float: sign bit, 11 bits exponent, 52 bits mantissa
'np_complex': np.complex, # Shorthand for complex128.
'np_complex64': np.complex64, # Complex number, represented by two 32-bit floats (real and imaginary components)
'np_complex128': np.complex128}
for k, t in np_types.items():
add_contract(Keyword(k).setParseAction(CheckType.parse_action(t)))
add_keyword(k)