"""The AaC Check AaC plugin implementation module."""
# NOTE: It is safe to edit this file.
# This file is only initially generated by aac gen-plugin, and it won't be overwritten if the file already exists.
from typing import Callable, Any
from aac.context.language_context import LanguageContext
from aac.context.definition import Definition
from aac.execute.aac_execution_result import (
ExecutionResult,
ExecutionStatus,
ExecutionMessage,
MessageLevel,
)
from aac.context.language_error import LanguageError
# Plugin name reported in the ExecutionResult returned by `check`.
plugin_name: str = "Check AaC"
# Checking a definition means recursing through its schema (and any nested schemas and
# primitives) to run every applicable constraint; the helper functions below perform that walk.
[docs]
def check_primitive_constraint(
    field: Any,
    source_definition: Definition,
    value_to_check: Any,
    primitive_declaration: str,
    defining_primitive: Any,
    all_constraints_by_name: (dict[str, Callable]),
    constraint_results: dict[str, list[ExecutionResult]]
) -> dict[str, list[ExecutionResult]]:
    """
    Run every constraint declared on a primitive against a single value.

    Each constraint callback is looked up by name in ``all_constraints_by_name`` and called as
    ``callback(value_to_check, primitive_declaration, arguments, source_definition.source, location)``,
    where ``location`` is the location of the first lexeme in ``source_definition`` whose value
    equals ``field.name`` (or ``None`` if there is no such lexeme). Each result is appended to
    ``constraint_results`` under the constraint's name.

    Args:
        field (Any): The field being checked; only its ``name`` is used here.
        source_definition (Definition): Definition containing the value; supplies the lexemes used
            to locate the field and the source passed to each callback.
        value_to_check (Any): The value being checked.
        primitive_declaration (str): The primitive type declaration passed through to each callback.
        defining_primitive (Any): The primitive instance whose ``constraints`` are run. A ``None``
            constraint list is treated as empty.
        all_constraints_by_name (dict[str, Callable]): Constraint callbacks keyed by constraint name.
        constraint_results (dict[str, list[ExecutionResult]]): Accumulated results; updated in place.

    Returns:
        dict[str, list[ExecutionResult]]: ``constraint_results`` with the new results appended.

    Raises:
        KeyError: If a constraint on the primitive has no registered callback.
    """
    # The field's location is the same for every constraint, so resolve it once up front.
    # NOTE(review): matching lexemes by text is a heuristic (the original author flagged it as a
    # hack); it picks the first lexeme whose text equals the field name.
    location = next(
        (
            lexeme.location
            for lexeme in source_definition.lexemes
            if lexeme.value == field.name
        ),
        None,
    )
    # Treat a None constraint list like an empty one.
    for constraint_assignment in defining_primitive.constraints or []:
        constraint_name = constraint_assignment.name
        callback = all_constraints_by_name[constraint_name]
        result: ExecutionResult = callback(
            value_to_check,
            primitive_declaration,
            constraint_assignment.arguments,
            source_definition.source,
            location,
        )
        constraint_results.setdefault(constraint_name, []).append(result)
    return constraint_results
def _run_primitive_constraint_list(
    check_me: Any,
    field: Any,
    source_definition: Definition,
    field_defining_schema: list[Definition],
    all_constraints_by_name: dict[str, Callable],
    constraint_results: dict[str, list[ExecutionResult]]
) -> dict[str, list[ExecutionResult]]:
    """
    Run primitive constraints on every element of a list-valued field.

    Elements that are ``None`` are skipped. Each remaining element is checked with
    ``check_primitive_constraint`` using the field type with its trailing ``[]`` removed.

    Args:
        check_me (Any): The instance holding the list field.
        field (Any): The field being checked; its ``type`` is expected to end with ``[]``.
        source_definition (Definition): Source of the check_me field that we are evaluating.
        field_defining_schema (list[Definition]): Definitions found for the element type; only the
            first entry's ``instance`` is used as the defining primitive.
        all_constraints_by_name (dict[str, Callable]): A dictionary of all constraint names and function calls.
        constraint_results (dict[str, list[ExecutionResult]]): A dictionary of constraint results.

    Returns:
        dict[str, list[ExecutionResult]]: An updated dictionary of constraint results.

    Raises:
        LanguageError: If the field's value is not a list.
    """
    values = getattr(check_me, field.name)
    if not isinstance(values, list):
        raise LanguageError(
            f"Value of '{field.name}' was expected to be list, but was '{type(values)}'",
            source_definition.source.uri
        )
    # Strip the trailing "[]" so constraints see the element's declaration, not the list's.
    item_declaration = field.type[:-2]
    defining_primitive = field_defining_schema[0].instance
    for item in values:
        if item is not None:
            constraint_results = check_primitive_constraint(
                field,
                source_definition,
                item,
                item_declaration,
                defining_primitive,
                all_constraints_by_name,
                constraint_results
            )
    return constraint_results
def _run_primitive_constraint_not_list(
    check_me: Any,
    field: Any,
    source_definition: Definition,
    field_defining_schema: list[Definition],
    all_constraints_by_name: (dict[str, Callable]),
    constraint_results: dict[str, list[ExecutionResult]]
) -> dict[str, list[ExecutionResult]]:
    """
    Run primitive constraints on a single (non-list) field value.

    A ``None`` value is not checked; otherwise the value is checked with
    ``check_primitive_constraint`` using the field's full declared type.

    Args:
        check_me (Any): The instance holding the field.
        field (Any): The field being checked.
        source_definition (Definition): Source of the check_me field that we are evaluating.
        field_defining_schema (list[Definition]): Definitions found for the field type; only the
            first entry's ``instance`` is used as the defining primitive.
        all_constraints_by_name (dict[str, Callable]): A dictionary of all constraint names and function calls.
        constraint_results (dict[str, list[ExecutionResult]]): A dictionary of constraint results.

    Returns:
        dict[str, list[ExecutionResult]]: An updated dictionary of constraint results.
    """
    value_to_check = getattr(check_me, field.name)
    if value_to_check is None:
        return constraint_results
    return check_primitive_constraint(
        field,
        source_definition,
        value_to_check,
        field.type,
        field_defining_schema[0].instance,
        all_constraints_by_name,
        constraint_results
    )
def _collect_schema_constraints(check_against: Definition) -> list:
    """
    Gather the schema constraint assignments that apply to a schema.

    Every universal schema constraint registered by any plugin runner is included first, as a
    ``SchemaConstraintAssignment`` with no arguments, followed by the constraint assignments
    declared on ``check_against`` itself (if it declares any).

    Args:
        check_against (Definition): The schema whose applicable constraints are collected.

    Returns:
        list: The collected constraint assignments, universal ones first.
    """
    context = LanguageContext()
    collected: list = []
    for runner in context.get_plugin_runners():
        collected.extend(
            context.create_aac_object(
                "SchemaConstraintAssignment",
                {"name": constraint.name, "arguments": []},
            )
            for constraint in runner.plugin_definition.instance.schema_constraints
            if constraint.universal
        )
    # Explicitly assigned constraints are optional on a schema.
    if check_against.constraints:
        collected.extend(check_against.constraints)
    return collected
def _check_against_defined_schema_constraints(
    schema_constraints: list,
    source_definition: Definition,
    check_me: Any,
    check_against: Definition,
    all_constraints_by_name: dict[str, Callable],
    constraint_results: dict[str, list[ExecutionResult]]
) -> dict[str, list[ExecutionResult]]:
    """
    Run each schema-level constraint assignment against an instance and record the results.

    Each callback is called as ``callback(check_me, source_definition, check_against, arguments)``
    and its result is appended to ``constraint_results`` under the constraint's name.

    Args:
        schema_constraints (list): Constraint assignments (each with ``name`` and ``arguments``) to run.
        source_definition (Definition): Source of the check_me field that we are evaluating.
        check_me (Any): The instance being checked.
        check_against (Definition): The schema that ``check_me`` is checked against.
        all_constraints_by_name (dict[str, Callable]): Constraint callbacks keyed by constraint name.
        constraint_results (dict[str, list[ExecutionResult]]): Accumulated results; updated in place.

    Returns:
        dict[str, list[ExecutionResult]]: The updated constraint results.

    Raises:
        KeyError: If an assignment names a constraint with no registered callback.
    """
    for assignment in schema_constraints:
        name = assignment.name
        arguments = assignment.arguments
        run_constraint = all_constraints_by_name[name]
        outcome: ExecutionResult = run_constraint(
            check_me, source_definition, check_against, arguments
        )
        # Record only after the callback returns, so a raising callback leaves no empty entry.
        constraint_results.setdefault(name, []).append(outcome)
    return constraint_results
def _check_field_against_constraint(
    source_definition: Definition,
    check_me: Any,
    check_against: Definition,
    all_constraints_by_name: dict[str, Callable],
    constraint_results: dict[str, list[ExecutionResult]]
) -> dict[str, list[ExecutionResult]]:
    """
    Check each field value on an instance against the type declared for it in a schema.

    Fields that ``check_me`` does not have are skipped. For the rest, the declared type is
    resolved to exactly one definition. Primitive types are handed to the primitive constraint
    helpers. Schema types are checked recursively with ``check_schema_constraint``, once per
    element when the field is a list.

    Args:
        source_definition (Definition): Definition containing ``check_me``; used for error
            reporting and passed down to the constraint helpers.
        check_me (Any): The instance whose field values are checked.
        check_against (Definition): The schema whose ``fields`` describe ``check_me``.
        all_constraints_by_name (dict[str, Callable]): Constraint callbacks keyed by name.
        constraint_results (dict[str, list[ExecutionResult]]): Accumulated constraint results.

    Returns:
        dict[str, list[ExecutionResult]]: The updated constraint results.

    Raises:
        LanguageError: If a field's type does not resolve to exactly one definition, or if a
            primitive list field's value is not a list.
    """
    context = LanguageContext()
    for field in check_against.fields:
        # Only fields actually present on the instance are checked.
        if not hasattr(check_me, field.name):
            continue

        # Reduce the declared type to a bare definition name: drop a trailing "[]" list
        # marker, then everything from the first "(" onward (type parameters).
        is_list = field.type.endswith("[]")
        bare_type_name = field.type.removesuffix("[]").partition("(")[0]

        candidates = context.get_definitions_by_name(bare_type_name)
        if len(candidates) != 1:
            # Question: should we convert this to a Constraint Failure?
            raise LanguageError(
                f"Could not find unique schema definition for field type {field.type} with name {field.name}",
                source_definition.source.uri
            )

        if candidates[0].get_root_key() == "primitive":
            run_primitive = (
                _run_primitive_constraint_list if is_list else _run_primitive_constraint_not_list
            )
            constraint_results = run_primitive(
                check_me,
                field,
                source_definition,
                candidates,
                all_constraints_by_name,
                constraint_results,
            )
            continue

        # Schema-typed field: recurse into the nested value, or into each element of a list.
        nested_schema = candidates[0].instance
        field_value = getattr(check_me, field.name)
        for nested_value in (field_value if is_list else [field_value]):
            constraint_results = check_schema_constraint(
                source_definition,
                nested_value,
                nested_schema,
                all_constraints_by_name,
                constraint_results,
            )
    return constraint_results
[docs]
def check_schema_constraint(
    source_definition: Definition,
    check_me: Any,
    check_against: Definition,
    all_constraints_by_name: dict[str, Callable],
    constraint_results: dict[str, list[ExecutionResult]]
) -> dict[str, list[ExecutionResult]]:
    """
    Run every schema constraint that applies to an instance, then check its fields.

    If ``check_against`` is not an ``aac.lang.Schema`` instance, nothing is checked and
    ``constraint_results`` is returned unchanged. Otherwise, two things run against ``check_me``:
    the universal schema constraints and the constraints assigned on ``check_against``.
    Each field declared on the schema is then checked as well.

    Args:
        source_definition (Definition): Source of the check_me field that we are evaluating.
        check_me (Any): The instance being checked.
        check_against (Definition): The schema we are comparing check_me against.
        all_constraints_by_name (dict[str, Callable]): Constraint callbacks keyed by constraint name.
        constraint_results (dict[str, list[ExecutionResult]]): Accumulated constraint results.

    Returns:
        dict[str, list[ExecutionResult]]: The updated constraint results.

    Raises:
        LanguageError: If unique schema definition for field type not found for field name
        LanguageError: If value of field name was something other than a list
    """
    if LanguageContext().is_aac_instance(check_against, "aac.lang.Schema"):
        constraint_results = _check_against_defined_schema_constraints(
            _collect_schema_constraints(check_against),
            source_definition,
            check_me,
            check_against,
            all_constraints_by_name,
            constraint_results,
        )
        return _check_field_against_constraint(
            source_definition,
            check_me,
            check_against,
            all_constraints_by_name,
            constraint_results,
        )
    # Not a schema: there are no schema constraints or fields to check here.
    return constraint_results
[docs]
def check_context_constraint(
    context_constraint: Definition,
    definitions_to_check: list[Definition],
    all_constraints_by_name: dict[str, Callable],
    constraint_results: dict[str, list[ExecutionResult]]
) -> dict[str, list[ExecutionResult]]:
    """
    Run a single context (language-wide) constraint and record its result.

    The constraint's callback is called with the ``LanguageContext``. The constraint is skipped
    when one of ``definitions_to_check`` has the same name as the constraint. This avoids running
    a constraint against the very file that defines it. It is also skipped when no callback is
    registered under its name.

    Args:
        context_constraint (Definition): The context constraint to run (only its ``name`` is used).
        definitions_to_check (list[Definition]): The definitions loaded from the file being checked.
        all_constraints_by_name (dict[str, Callable]): Constraint callbacks keyed by constraint name.
        constraint_results (dict[str, list[ExecutionResult]]): Accumulated results; updated in place.

    Returns:
        dict[str, list[ExecutionResult]]: The updated constraint results.
    """
    context: LanguageContext = LanguageContext()
    name = context_constraint.name
    defined_in_checked_file = name in (definition.name for definition in definitions_to_check)
    if defined_in_checked_file or name not in all_constraints_by_name:
        return constraint_results
    outcome: ExecutionResult = all_constraints_by_name[name](context)
    constraint_results.setdefault(name, []).append(outcome)
    return constraint_results
def _check_constraint_results(
    constraint_results: dict[str, list[ExecutionResult]], verbose: bool, fail_on_warn: bool
) -> tuple[ExecutionStatus, list]:
    """
    Fold all constraint results into one overall status and a list of messages.

    Results are handled according to their outcome:

    - Successful results contribute their messages only in verbose mode.
    - Warnings always contribute their messages. They change the status to CONSTRAINT_FAILURE
      only when ``fail_on_warn`` is set.
    - Any other result contributes its messages. Its status becomes the overall status unless
      the overall status is already CONSTRAINT_FAILURE, which is never overwritten.

    Args:
        constraint_results (dict[str, list[ExecutionResult]]): Constraint results keyed by constraint name.
        verbose (bool): Flag for verbose mode. When true add success messages as encountered.
        fail_on_warn (bool): Flag to fail when warnings are discovered.

    Returns:
        tuple[ExecutionStatus, list]: The overall status and the list of collected messages.
    """
    status = ExecutionStatus.SUCCESS
    messages: list = []
    # Keys (constraint names) are not needed here, only the results.
    for results in constraint_results.values():
        for result in results:
            if result.is_success():
                # Success results should only carry info messages, so report them only when verbose.
                if verbose:
                    messages.extend(result.messages)
            elif result.status_code == ExecutionStatus.CONSTRAINT_WARNING:
                messages.extend(result.messages)
                if fail_on_warn:
                    status = ExecutionStatus.CONSTRAINT_FAILURE
            else:
                # Any failure (including a constraint failure) is handled the same way.
                messages.extend(result.messages)
                # Don't change the status if it is already a constraint failure.
                if status != ExecutionStatus.CONSTRAINT_FAILURE:
                    status = result.status_code
    return status, messages
def _collect_all_constraints_by_name() -> dict[str, Callable]:
    """
    Build a constraint-name -> callback lookup from every registered plugin runner.

    If two runners register the same constraint name, the runner visited later wins.

    Returns:
        dict[str, Callable]: Constraint callbacks keyed by constraint name.
    """
    runners = LanguageContext().get_plugin_runners()
    return {
        constraint_name: callback
        for runner in runners
        for constraint_name, callback in runner.constraint_to_callback.items()
    }
[docs]
def check(aac_file: str, fail_on_warn: bool, verbose: bool) -> ExecutionResult:
    """
    Check the relevant constraints for the definition(s) in an AaC file.

    Context constraints (global, language-wide constraints) run first. Schema constraints (those
    assigned to each definition's defining schema) run next. Primitive constraints are run as part
    of the schema constraint check.

    Args:
        aac_file (str): The AaC file being processed.
        fail_on_warn (bool): When true, constraint warnings cause the check to fail.
        verbose (bool): When true, include messages from successful constraints and a
            per-definition message for every checked definition.

    Returns:
        ExecutionResult: Result with plugin name "Check AaC", command "check", the overall status,
        and all messages collected from the lower-level helpers.

    Raises:
        LanguageError: If no defining schema is found for a definition's root key (re-raised with
            the definition's source URI), or if checking a definition's schema constraints raises
            one. ``parse_and_load`` may also raise (see the FIX ME notes in the body).
    """
    constraint_results: dict[str, list[ExecutionResult]] = {}
    context: LanguageContext = LanguageContext()
    # collect all constraints for easy access
    all_constraints_by_name = _collect_all_constraints_by_name()
    # FIX ME: This call to parse_and_load can throw LanguageError and ParserError exceptions which are not being handled here
    # FIX ME: They should be handled, their messages added to the message list being constructed in this method (complete with source and location info)
    # FIX ME: so they can be passed along. This was discovered during the expansion of test_check_aac.py and those tests would need to be updated.
    definitions_to_check = context.parse_and_load(aac_file)

    # First run all context constraint checks.
    # Context constraints are "language constraints" and are not tied to a specific schema.
    # You can think of these as "invariants", so they must always be satisfied.
    for plugin in context.get_definitions_by_root("plugin"):
        # Constraints defined in the aac_file being checked are skipped inside
        # check_context_constraint to avoid gen-plugin circular logic.
        for context_constraint in plugin.instance.context_constraints:
            constraint_results = check_context_constraint(
                context_constraint, definitions_to_check, all_constraints_by_name, constraint_results
            )

    for check_me in definitions_to_check:
        try:
            defining_schema = context.get_defining_schema_for_root(check_me.get_root_key())
        except LanguageError as e:
            # Re-raise with the definition's source URI, keeping the original error as the cause.
            raise LanguageError(e.message, check_me.source.uri) from e
        # Check the schema constraints; primitive constraints are run as part of this check.
        constraint_results = check_schema_constraint(
            check_me, check_me.instance, defining_schema.instance, all_constraints_by_name, constraint_results
        )

    # Fold all the constraint results into an overall status and message list.
    status, messages = _check_constraint_results(constraint_results, verbose, fail_on_warn)

    # In verbose mode, add a per-definition message for every checked definition.
    # NOTE(review): these "was successful" messages are emitted even when `status` is a failure;
    # confirm whether they should be limited to the success case.
    if verbose:
        for check_me in definitions_to_check:
            messages.append(
                ExecutionMessage(
                    f"Check {check_me.source.uri} - {check_me.name} was successful.",
                    level=MessageLevel.DEBUG,
                    source=check_me.source.uri,
                    location=None,
                )
            )
    # Only when the aggregated status is still SUCCESS, add the overall success message.
    if status == ExecutionStatus.SUCCESS:
        messages.append(
            ExecutionMessage(
                message="All AaC constraint checks were successful.",
                level=MessageLevel.INFO,
                source=aac_file,
                location=None,
            )
        )
    return ExecutionResult(plugin_name, "check", status, messages)