Source code for osivalidator.osi_rules_checker

"""
This module contains all the rules which a message or an attribute of a message
from an OSI scenario can comply.
"""

from iso3166 import countries
from asteval import Interpreter
from google.protobuf.json_format import MessageToDict
from osi3.osi_groundtruth_pb2 import GroundTruth

from .osi_validation_rules import Severity, MessageType
from .osi_validator_logger import SEVERITY


class OSIRulesChecker:
    """This class contains all the available rules to write OSI requirements
    and the necessary methods to check their compliance.

    The rule methods are marked with *Rule*. They receive their arguments as
    keyword arguments:

    - ``inherit``: list of ``(field descriptor, value)`` tuples from the root
      message down to the checked value (the root tuple has ``None`` as its
      descriptor);
    - ``severity``: the severity used to log a failed check;
    - ``params``: the rule parameters, whose meaning depends on the rule;
    - ``rules``: the rules of the message that contains the checked field.
    """

    # Value of ``FieldDescriptor.label`` for repeated fields
    # (protobuf ``FieldDescriptor.LABEL_REPEATED``).
    _LABEL_REPEATED = 3

    def __init__(self, ovr, logger, id_manager, ignore_lanes):
        # ``ovr.t_rules`` exposes ``get_type`` to look up the rules of a
        # message type (see ``is_valid``).
        self.rules = ovr.t_rules
        self.logger = logger
        self._id_manager = id_manager
        # Passed with every log entry; -1 until ``set_timestamp`` is called.
        self.timestamp = -1
        # When true, "lane_boundary" fields of GroundTruth messages are skipped.
        self.ignore_lanes = ignore_lanes

    # Rules implementation

    def is_set(self, **kwargs):
        """*Rule* Check if a field is set.

        The Python function is actually a wrapper of ``is_valid``. The setting
        of the field is checked during the exploration of the fields (see
        ``check_message``). Non-message values always pass.
        """
        inherit = kwargs.get('inherit', None)
        if hasattr(inherit[-1][1], "DESCRIPTOR"):
            return self.is_valid(**kwargs)
        return True

    def is_set_if(self, **kwargs):
        """*Rule* A wrapper to the function ``is_set``.

        The condition should be contained in `params` as a string but is
        checked during the exploration of the message.

        :param params: The assertion in Python-style pseudo-code as a string.
        """
        return self.is_set(**kwargs)

    def is_valid(self, **kwargs):
        """*Rule* Check if a field message is valid, that is all the inner
        rules of the message in the field are complying.
        """
        inherit = kwargs.get('inherit', None)
        field_type_desc = inherit[-1][1].DESCRIPTOR
        # Build the type path from the outermost containing type down to the
        # field's own type, e.g. ['Outer', 'Inner'] for a nested message type.
        message_t_inherit = []
        while field_type_desc is not None:
            message_t_inherit.insert(0, field_type_desc.name)
            field_type_desc = field_type_desc.containing_type
        return self.check_message(inherit,
                                  self.rules.get_type(message_t_inherit))

    def is_minimum(self, **kwargs):
        """*Rule* Check if a number is over a minimum (inclusive).

        :param params: the minimum
        """
        inherit = kwargs.get('inherit', None)
        severity = kwargs.get('severity', None)
        minimum = kwargs.get('params', None)
        self._log('debug', f'Minimum: {minimum}')
        value = inherit[-1][1]
        res = value >= minimum
        if not res:
            self._log(severity,
                      f'{get_message_path(inherit)} = {value} is too low '
                      + f'(minimum: {minimum})')
        return res

    def is_maximum(self, **kwargs):
        """*Rule* Check if a number is under a maximum (inclusive).

        :param params: the maximum
        """
        inherit = kwargs.get('inherit', None)
        severity = kwargs.get('severity', None)
        maximum = kwargs.get('params', None)
        self._log('debug', f'Maximum: {maximum}')
        value = inherit[-1][1]
        # A value complies when it does not exceed the maximum.
        res = value <= maximum
        if not res:
            self._log(severity,
                      f'{get_message_path(inherit)} = {value} is too high '
                      + f'(maximum: {maximum})')
        return res

    def in_range(self, **kwargs):
        """*Rule* Check if a number is in a range.

        :param params: must be a table. The first element is the minimum, the
            second element is the maximum. The third element, if present, is a
            parameter string:

            - if it contains 'lo', the interval is left-open
            - if it contains 'ro', the interval is right-open

            The interval can be 'loro', that is left and right-open.
        """
        inherit = kwargs.get('inherit', None)
        severity = kwargs.get('severity', None)
        interval = kwargs.get('params', None)
        mini = float(interval[0])
        maxi = float(interval[1])
        value = inherit[-1][1]
        # An open bound excludes a value that is exactly equal to it.
        is_equal_to_bound = len(interval) >= 3 and (
            'lo' in interval[2] and mini == value
            or 'ro' in interval[2] and maxi == value)
        result = mini <= value <= maxi and not is_equal_to_bound
        n_in = "not " if not result else ""
        message_model = \
            f'{get_message_path(inherit)}= {value} {n_in}in range: {mini, maxi}'
        log_severity = "debug" if result else severity
        self._log(log_severity, message_model)
        return result

    def is_global_unique(self, **kwargs):
        """*Rule* Register an ID in the OSI ID manager to later perform an ID
        consistency validation.

        Must be set to an Identifier.
        """
        inherit = kwargs.get('inherit', None)
        object_of_id = inherit[-2][1]
        identifier = inherit[-1][1].value
        return self._id_manager.register_message(identifier, object_of_id)

    def refers(self, **kwargs):
        """*Rule* Add a reference to another message by ID.

        **TODO**: the conditional reference. There is still no use case in
        OSI, so this is left pending.

        :param params: the expected type of the referred object
        """
        expected_type = kwargs.get('params', None)
        inherit = kwargs.get('inherit', None)
        referer = inherit[-2][1]
        identifier = inherit[-1][1].value
        condition = None
        self._id_manager.refer(referer, identifier, expected_type, condition)
        return True

    def is_iso_country_code(self, **kwargs):
        """*Rule* Check if a string is an ISO country code."""
        severity = kwargs.get('severity', None)
        inherit = kwargs.get('inherit', None)
        self._log('debug', f'Checking ISO code for {inherit[-1][1]}')
        iso_code = inherit[-1][1]
        try:
            countries.get(iso_code)
        except KeyError:
            self._log(severity, f'ISO code {iso_code} is not valid')
            return False
        self._log("debug", f'ISO code {iso_code} is valid')
        return True

    def first_element(self, **kwargs):
        """*Rule* Check rule for first message of a repeated field.

        :param params: dictionary of rules to be checked for the first message
        """
        inherit = kwargs.get('inherit', None)
        params = kwargs.get('params', None)
        virtual_message = MessageType('', params, None)
        return self.check_message(
            inherit + [(None, inherit[-1][1][0])], virtual_message)

    def last_element(self, **kwargs):
        """*Rule* Check rule for last message of a repeated field.

        :param params: dictionary of rules to be checked for the last message
        """
        inherit = kwargs.get('inherit', None)
        params = kwargs.get('params', None)
        virtual_message = MessageType('', params, None)
        return self.check_message(
            inherit + [(None, inherit[-1][1][-1])], virtual_message)

    def _check_repeated(self, **kwargs):
        """Apply ``rule_method`` to a repeated field.

        ``first_element`` / ``last_element`` receive the whole repeated field;
        any other rule is applied to every element in turn.
        """
        severity = kwargs.get('severity', None)
        inherit = kwargs.get('inherit', None)
        rule_method = kwargs.get('rule_method', None)
        rules = kwargs.get('rules', None)
        params = kwargs.get('params', None)
        self._log('debug', f'Check the rule {rule_method.__name__} for a repeated field')

        if rule_method.__name__ == "each":
            rules = params

        if rule_method.__name__ in ['first_element', 'last_element']:
            return rule_method(
                severity=severity, inherit=inherit, rules=rules, params=params)

        # A list (not a generator) so that every element is checked and
        # logged, even after a first failure.
        return all([
            rule_method(severity=severity, inherit=inherit + [(None, m)],
                        rules=rules, params=params)
            for m in inherit[-1][1]])

    def set_timestamp(self, timestamp):
        """Set the timestamp for the analysis.

        :param timestamp: object with ``seconds`` and ``nanos`` attributes
        :return: the timestamp converted to integer nanoseconds
        """
        # Exact integer arithmetic: 1 s = 1_000_000_000 ns.
        self.timestamp = int(timestamp.seconds) * 1_000_000_000 \
            + int(timestamp.nanos)
        return self.timestamp

    def check_message(self, inherit, rules, id_manager=None):
        """Method to check the rules for a complex message.

        It is also the input method.

        :param inherit: a list representing the inheritance of the processed
            message in tuples
        :param rules: the rules rooted at the type of the processed message
        :param id_manager: unused; kept for backward compatibility
        :return: True if every rule of every present field complies

        .. note:: inherit parameter must have this structure:
            ``[(None, Root message), (Field descriptor, Child message), ...]``
            The last tuple represents the processed message.
        """
        final_res = True
        message = inherit[-1][1]
        # Built lazily: only needed when an "is_set_if" condition is evaluated.
        dict_message = None

        # Add "is_valid" rule for each field that can be validated (default)
        add_default_valid_rules(message, rules)

        for field_name, f_rules in rules.fields.items():
            field_path = get_message_path(inherit) + "." + field_name
            is_present = has_attr(message, field_name)

            # check the rule "is_set"
            if f_rules.must_be_set and not is_present:
                self._log(f_rules.rules['is_set'].severity,
                          f"{field_path} is not set!")

            # "is_set_if" is the conditional version of the rule "is_set"
            if f_rules.must_be_set_if is not None:
                cond = f_rules.must_be_set_if
                if dict_message is None:
                    dict_message = MessageToDict(
                        message,
                        preserving_proto_field_name=True,
                        use_integers_for_enums=True)
                if Interpreter(dict_message)(cond) and not is_present:
                    self._log(f_rules.rules['is_set_if'].severity,
                              f"{field_path} not set as expected: {cond}")

            if not is_present:
                self._log('debug', f'Field {field_path} does not exist')
                continue

            proto_field_tuple = next(
                field for field in message.ListFields()
                if field[0].name == field_name)
            child_inherit = inherit + [proto_field_tuple]
            if not self._loop_over_rules(rules, f_rules, child_inherit):
                final_res = False

        # Resolve ID and references once the root message has been explored
        if len(inherit) == 1:
            self._id_manager.resolve_unicity(self.timestamp)
            self._id_manager.resolve_references(self.timestamp)

        return final_res

    def _loop_over_rules(self, rules, field_rules, child_inherit):
        """Run every rule of ``field_rules`` on the field at ``child_inherit``.

        :return: False if any rule failed, True otherwise
        """
        final_res = True
        for rule_obj in field_rules.rules.values():
            verb = rule_obj.verb
            params = rule_obj.params
            severity = SEVERITY[rule_obj.severity]

            try:
                rule_method = getattr(self, verb)
            except AttributeError:
                self._log('error', f'Rule "{verb}" not implemented yet!')
                continue

            field_desc = child_inherit[-1][0]
            if field_desc.label == self._LABEL_REPEATED:
                if self.ignore_lanes \
                        and field_desc.name == 'lane_boundary' \
                        and isinstance(child_inherit[-2][1], GroundTruth):
                    continue
                res = self._check_repeated(
                    severity=severity, rule_method=rule_method,
                    inherit=child_inherit, rules=rules, params=params)
            else:
                res = rule_method(
                    severity=severity, inherit=child_inherit,
                    rules=rules, params=params)

            if not res:
                final_res = False
        return final_res

    def _log(self, severity, message):
        """Log ``message`` with the current timestamp.

        :param severity: a ``Severity`` enum (mapped through ``SEVERITY``) or
            the name of a logger method as a string
        :raises TypeError: for any other severity type
        """
        if isinstance(severity, Severity):
            severity_method = SEVERITY[severity]
        elif isinstance(severity, str):
            severity_method = severity
        else:
            raise TypeError('type not accepted: must be Severity enum or str')

        return getattr(self.logger, severity_method)(self.timestamp, message)
def get_message_path(inherit):
    """Return the path to a message from the inheritance list of the message.

    Entries whose descriptor is ``None`` (the root, or virtual elements of a
    repeated field) are skipped; the remaining field names are joined by dots.
    """
    not_none_elt = filter(lambda i: i[0] is not None, inherit)
    return ".".join(map(lambda i: i[0].name, not_none_elt))


def has_attr(message, field_name):
    """Check if a message has an attribute/field, even if it is a repeated
    field.

    :return: True for a set singular field, a non-empty repeated field, or a
        scalar without presence tracking that appears in ``ListFields()``;
        False otherwise (including unknown field names).
    """
    try:
        return message.HasField(field_name)
    except ValueError:
        # HasField is not available for this field (repeated field, unknown
        # name, or scalar without presence tracking).
        pass

    try:
        value = getattr(message, field_name)
    except AttributeError:
        return False

    try:
        return len(value) > 0
    except TypeError:
        # Scalar without presence tracking: treat it as set when ListFields()
        # reports it, which is also where check_message fetches the field.
        return any(desc.name == field_name
                   for desc, _ in message.ListFields())


def add_default_valid_rules(message, rules):
    """Add the "is_valid" rule to the message-typed fields of ``message``.

    Every message-typed field returned by ``message.ListFields()`` receives
    the ``is_valid`` rule when it has no rules yet (a rule entry is created)
    or when it is not marked as "must be set".
    """
    def is_validable(field):
        # ``field`` is a (descriptor, value) tuple from ListFields(); only
        # message-typed fields can be validated recursively.
        return field[0].message_type is not None

    for desc, _ in filter(is_validable, message.ListFields()):
        if desc.name not in rules.fields:
            rules.add_field(desc.name)
            rules[desc.name].add_rule('is_valid')
        elif not rules.fields[desc.name].must_be_set:
            rules[desc.name].add_rule('is_valid')