from dataclasses import dataclass from typing import Dict, List as List_, Optional, Tuple from ppp_ast import * from ppp_lexer import Lexer from ppp_object import Bool, EnumValue, Function, Int, Object, Str, Struct, Tuple as TupleObject, List as ListObject, TypeObject, Void from ppp_parser import is_valid_target, parse_statement from ppp_tokens import EofToken from ppp_stdlib import variables from ppp_types import EnumType, FunctionType, GenericType, Int as IntType, ListType, Str as StrType, StructType, TupleType, Type, TypeType, VariableType, Void as VoidType @dataclass class Declared: type: Type value: Object @staticmethod def from_obj(obj: Object) -> 'Declared': return Declared(obj.get_type(), obj) @dataclass class Undeclared: type: Type @dataclass class Constant: type: Type value: Object @staticmethod def from_obj(obj: Object) -> 'Declared': return Declared(obj.get_type(), obj) VariableState = Declared | Undeclared | Constant Module = Dict[str, VariableState] @dataclass class ProgramState: modules: Dict[str, Module] # TODO: What is the type of module? contexts: List_[Dict[str, VariableState]] def push_context(self, variables: Dict[str, VariableState]): self.contexts.append(variables) def pop_context(self): self.contexts.pop() def declare_variable(self, name: str, type: Type): assert not (name in self.contexts[-1]), f"'{name}' has already been declared!" self.contexts[-1][name] = Undeclared(type) def assign_variable(self, name: str, value: Object): for context in self.contexts[::-1]: if name in context: assert value.get_type().is_subtype_of(context[name].type), f"In the assignment of '{name}', expected value of type {context[name].type.represent()}, but got a value of type {value.get_type().represent()}!" context[name] = Declared(context[name].type, value) return assert False, f"'{name}' doesn't exist!" def declare_and_assign_variable(self, name: str, value: Object): self.declare_variable(name, value.get_type()) self.assign_variable(name, value) def exists(self, name: str) -> bool: for context in self.contexts[::-1]: if name in context: return True return False def access_variable(self, name: str) -> Object: for context in self.contexts[::-1]: if name in context: value = context[name] assert not isinstance(value, Undeclared), f"{name} is not declared!" return value.value assert False, f"'{name}' is not defined!" def is_truthy(object: Object) -> bool: match object: case Bool(value): return value case _: assert False, ("Unimplemented", object) def calculate_expression(expression: Expression, program: ProgramState) -> Object: match expression: case FunctionCall(function_, arguments_): function = calculate_expression(function_, program) assert isinstance(function, Function), (function_, function) name, parameters, return_type, body, func = function.function arguments = [calculate_expression(argument, program) for argument in arguments_] assert len(arguments) == len(parameters), f"{name} expected {len(parameters)} arguments, but got {len(arguments)}!" for (argument, (parameter_name, parameter)) in zip(arguments, parameters): assert argument.get_type().is_subtype_of(parameter), f"For argument '{parameter_name}' of '{name}', expected value of type {parameter.represent()}, but got {argument.get_type().represent()}!" return_value = func(name, parameters, return_type, body, *arguments) assert isinstance(return_value, Object), return_value assert return_value.get_type().is_subtype_of(return_type) return return_value case Variable(name): return program.access_variable(name) case String(string): return Str(string) case Number(number): return Int(number) case TupleExpr(elements_): tuple_elements = [calculate_expression(element, program) for element in elements_] return TupleObject(TupleType([element.get_type() for element in tuple_elements]), tuple(tuple_elements)) case Ternary(condition_, if_true, if_false): return calculate_expression(if_true, program) if is_truthy(calculate_expression(condition_, program)) else calculate_expression(if_false, program) case Or(lhs, rhs): left_value = calculate_expression(lhs, program) assert isinstance(left_value, Bool) if left_value.value: return Bool(True) right_value = calculate_expression(rhs, program) assert isinstance(right_value, Bool) return Bool(left_value.value or right_value.value) case And(lhs, rhs): left_value = calculate_expression(lhs, program) assert isinstance(left_value, Bool) if not left_value.value: return Bool(False) right_value = calculate_expression(rhs, program) assert isinstance(right_value, Bool) return Bool(left_value.value and right_value.value) case Bor(lhs, rhs): assert False, ("Unimplemented", lhs, rhs) case Bxor(lhs, rhs): assert False, ("Unimplemented", lhs, rhs) case Band(lhs, rhs): assert False, ("Unimplemented", lhs, rhs) case Equal(lhs, rhs): left_value = calculate_expression(lhs, program) right_value = calculate_expression(rhs, program) if left_value.get_type() != right_value.get_type(): return Bool(False) return Bool(left_value == right_value) case NotEqual(lhs, rhs): left_value = calculate_expression(lhs, program) right_value = calculate_expression(rhs, program) if left_value.get_type() != right_value.get_type(): return Bool(True) return Bool(left_value != right_value) case LessThan(lhs, rhs): left_value = calculate_expression(lhs, program) right_value = calculate_expression(rhs, program) assert isinstance(left_value, Int) assert isinstance(right_value, Int) return Bool(left_value.num < right_value.num) case GreaterThan(lhs, rhs): left_value = calculate_expression(lhs, program) right_value = calculate_expression(rhs, program) assert isinstance(left_value, Int) assert isinstance(right_value, Int) return Bool(left_value.num > right_value.num) case LessThanOrEqual(lhs, rhs): left_value = calculate_expression(lhs, program) right_value = calculate_expression(rhs, program) assert isinstance(left_value, Int) assert isinstance(right_value, Int) return Bool(left_value.num <= right_value.num) case GreaterThanOrEqual(lhs, rhs): left_value = calculate_expression(lhs, program) right_value = calculate_expression(rhs, program) assert isinstance(left_value, Int) assert isinstance(right_value, Int) return Bool(left_value.num >= right_value.num) case ShiftLeft(lhs, rhs): assert False, ("Unimplemented", lhs, rhs) case ShiftRight(lhs, rhs): assert False, ("Unimplemented", lhs, rhs) case Addition(lhs, rhs): left_value = calculate_expression(lhs, program) right_value = calculate_expression(rhs, program) if isinstance(left_value, Int): assert isinstance(right_value, Int) return Int(left_value.num + right_value.num) elif isinstance(left_value, Str): assert isinstance(right_value, Str) return Str(left_value.str + right_value.str) elif isinstance(left_value, ListObject): assert isinstance(right_value, ListObject) if left_value.type.type == VariableType(""): return right_value if right_value.type.type == VariableType(""): return left_value assert left_value.type == right_value.type, (left_value, right_value) return ListObject(left_value.type, left_value.list + right_value.list) else: assert False, f"Expected two ints or two strs. Got {left_value.get_type().represent()} and {right_value.get_type().represent()}!" case Subtract(lhs, rhs): left_value = calculate_expression(lhs, program) right_value = calculate_expression(rhs, program) assert isinstance(left_value, Int) assert isinstance(right_value, Int) return Int(left_value.num - right_value.num) case Multiplication(lhs, rhs): left_value = calculate_expression(lhs, program) right_value = calculate_expression(rhs, program) assert isinstance(left_value, Int) assert isinstance(right_value, Int) return Int(left_value.num * right_value.num) case Division(lhs, rhs): assert False, ("Unimplemented", lhs, rhs) case Modulo(lhs, rhs): left_value = calculate_expression(lhs, program) right_value = calculate_expression(rhs, program) if isinstance(left_value, Int): assert isinstance(right_value, Int), f"Expected int, got {right_value.get_type().represent()}!" return Int(left_value.num % right_value.num) elif isinstance(left_value, Str): # TODO: Maybe actually just implement C-style string formatting? This code is a mess match right_value: case TupleObject(_, tuple_obj_elements): assert left_value.str.count("%"+"s") == len(tuple_obj_elements), (left_value.str.count("%%s"), len(tuple_obj_elements)) the_elements: Tuple[str, ...] = () for element in tuple_obj_elements: assert isinstance(element, Str) the_elements += (element.str,) return Str(left_value.str % the_elements) case Str(string): return Str(left_value.str % string) case _: assert False, f"Format string expected either a string or a tuple of strings, but got a '{right_value.get_type().represent()}'!\n{lhs, rhs, right_value}" match right_value: case Int(num): return Str(left_value.str % num) case _: assert False, ("Unimplemented", right_value) assert False, ("Unimplemented", lhs, rhs) case StructInstantiation(struct_, arguments_): struct = calculate_expression(struct_, program) assert isinstance(struct, TypeObject) assert isinstance(struct.type, StructType) struct_arguments = {name: calculate_expression(expression, program) for (name, expression) in arguments_} for field in struct_arguments: assert field in struct.type.members, f"The struct {struct.type.name} does not have the field '{field}'!" assert struct_arguments[field].get_type().is_subtype_of(struct.type.members[field]), f"'{struct.type.name}.{field}' field expected value of type {struct.type.members[field].represent()}, but got a value of type {struct_arguments[field].get_type().represent()}!" for field in struct.type.members: assert field in struct_arguments, f"Missing field '{field}' of type {struct.type.represent()}." return Struct(struct.type, struct_arguments) case FieldAccess(expression_, field): value = calculate_expression(expression_, program) match value: case TypeObject(type): match type: case EnumType(name, members, _): assert field in members, f"{type.represent()} does not contain the member '{field}'!" member = members[field] if not member: return EnumValue(type, field, []) def return_member(name: str, parameters: List_[Tuple[str, Type]], return_type: Type, _statement: Statement, *args: Object): assert isinstance(return_type, EnumType) assert len(args) == len(parameters) for (arg, (_, parameter)) in zip(args, parameters): assert arg.get_type().is_subtype_of(parameter) return EnumValue(return_type, name, list(args)) return Function(FunctionType(member, type), (field, [('', member_type) for member_type in member], type, Statements([]), return_member)) case _: assert False, ("Unimplemented", type, field) case Struct(type, fields): assert field in fields, f"Struct '{type.represent()}' does not have the field '{field}'!" return fields[field] case TupleObject(type, tuple_): assert field.isdigit(), field tuple_index = int(field) assert 0 <= tuple_index < len(type.types), f"Index {tuple_index} out of bounds for tuple of length {len(type.types)}" return tuple_[tuple_index] case _: assert False, ("Unimplemented", value, field) case ArrayAccess(array_, index_): array = calculate_expression(array_, program) assert array.get_type().is_indexable(), f"Objects of type {array.get_type().represent()} cannot be indexed!" if isinstance(array, Str): index = calculate_expression(index_, program) assert isinstance(index, Int), f"Index must be '{IntType.represent()}', got '{index.get_type().represent()}'!" assert 0 <= index.num < len(array.str), f"Index out of bounds. Str of length {len(array.str)} accessed at index {index.num}. {array_, index_}" return Str(array.str[index.num]) elif isinstance(array, TupleObject): assert False, f"Cannot use array index for tuple. Use `.` (e.g., `{array_.represent()}.0`) syntax instead" elif isinstance(array, ListObject): array_type = array.type.type index = calculate_expression(index_, program) assert isinstance(index, Int), f"Index must be '{IntType.represent()}', got '{index.get_type().represent()}'!" assert 0 <= index.num < len(array.list), f"Index out of bounds. List of length {len(array.list)} accessed at index {index.num}" element = array.list[index.num] assert element.get_type().is_subtype_of(array_type) return element else: assert False, "Unreachable" case Bnot(expression_): assert False, ("Unimplemented", expression_) case Not(expression_): value = calculate_expression(expression_, program) assert isinstance(value, Bool) return Bool(not value.value) case UnaryPlus(expression_): assert False, ("Unimplemented", expression_) case UnaryMinus (expression_): assert False, ("Unimplemented", expression_) case Array(element_type_, array_): element_type = calculate_type_expression(element_type_, program) array_elements_: List_[Object] = [] for element_ in array_: element = calculate_expression(element_, program) assert element.get_type().is_subtype_of(element_type), (element, element_type) array_elements_.append(element) return ListObject(ListType(element_type), array_elements_) case LoopComprehension(element_type_, body_, variable, array_): element_type = calculate_type_expression(element_type_, program) array = calculate_expression(array_, program) assert array.get_type().is_indexable() if isinstance(array, ListObject): elements: List_[Object] = [] for element in array.list: program.push_context({variable: Declared.from_obj(element)}) elements.append(calculate_expression(body_, program)) program.pop_context() assert elements[-1].get_type().is_subtype_of(element_type) return ListObject(ListType(element_type), elements) else: assert False, ("Unimplemented", array) case _: assert False, ("Unimplemented", expression) assert False def calculate_type_expression(expression: TypeExpression, program: ProgramState, must_resolve:bool=True) -> Type: match expression: case TypeName(name): if not program.exists(name) and not must_resolve: return VariableType(name) type_obj = program.access_variable(name) assert isinstance(type_obj, TypeObject) return type_obj.type case ListTypeExpr(type_): return ListType(calculate_type_expression(type_, program, must_resolve)) case TupleTypeExpr(types_): return TupleType([calculate_type_expression(type, program, must_resolve) for type in types_]) case FunctionTypeExpr(arguments_, return_type_): return FunctionType([calculate_type_expression(argument, program, must_resolve) for argument in arguments_], calculate_type_expression(return_type_, program, must_resolve)) case TypeSpecification(type_, types_): type = calculate_type_expression(type_, program, must_resolve) assert isinstance(type, GenericType) assert len(type.variables) == len(types_) types = [calculate_type_expression(type_, program, must_resolve) for type_ in types_] result_type = type.substitute(types) return result_type case _: assert False, ("Unimplemented", expression) assert False, "Unreachable" def match_enum_expression(enum: Type, value: Object, expression: Expression) -> Optional[Dict[str, Object]]: assert isinstance(enum, EnumType) assert isinstance(value, EnumValue) match expression: case Variable(name): if name.startswith('_'): return {name: value} assert name in enum.members, f"Enum '{enum.represent()}' does not contain the member '{name}'!" assert enum.members[name] == [], f"Enum member '{enum.represent()}.{name}' has {len(enum.members[name])} fields that have not been captured!" if name != value.name or value.values: return None return {} case FunctionCall(function, arguments): assert isinstance(function, Variable) assert function.name in enum.members, f"Enum '{enum.represent()}' does not contain the member '{function.name}'!" if function.name != value.name: return None member = enum.members[function.name] assert isinstance(member, list) # TODO: Report calling a struct enum member with parentheses assert isinstance(value.values, list) # Same as above but like inverse assert len(arguments) == len(member), f"{value.get_type().represent()}.{value.name} expected {len(member)} args, but got {len(arguments)}!" assert len(member) == len(value.values) new_variables: Dict[str, Object] = {} for argument, element in zip(arguments, value.values): assert isinstance(argument, Variable) # TODO new_variables[argument.name] = element return new_variables case _: assert False, ("Unimplemented", expression) assert False, ("Unimplemented", value, expression) def update_types(type: Type, program: ProgramState): assert isinstance(type, EnumType) or isinstance(type, StructType) for context in program.contexts: for variable_ in context: variable = context[variable_] match variable: case Declared(type_, value): if isinstance(variable.value, TypeObject): assert type_ == TypeType assert isinstance(value, TypeObject) value.type.fill({type.name: type}, []) case Undeclared(type_): pass case _: assert False, ("Unimplemented", variable) @dataclass class ReturnResult: value: Object @dataclass class ContinueResult: pass @dataclass class BreakResult: pass @dataclass class NothingResult: pass StatementsResult = ReturnResult | ContinueResult | BreakResult | NothingResult def interpret_statements(statements: List_[Statement], program: ProgramState) -> StatementsResult: for statement in statements: match statement: case ExpressionStatement(expression): calculate_expression(expression, program) case Assignment(lhs, rhs, type_): assert is_valid_target(lhs) match lhs: case Variable(name): value = calculate_expression(rhs, program) if type_: type = calculate_type_expression(type_, program) program.declare_variable(name, type) program.assign_variable(name, value) case FieldAccess(expression_, field): expr = calculate_expression(expression_, program) assert isinstance(expr, Struct) struct_type = expr.get_type() assert isinstance(struct_type, StructType) assert field in struct_type.members, f"Struct '{struct_type.represent()}' does not contain the field '{field}'!" value = calculate_expression(rhs, program) assert value.get_type().is_subtype_of(struct_type.members[field]) expr.fields[field] = value case ArrayAccess(array_, index_): array = calculate_expression(array_, program) index = calculate_expression(index_, program) value = calculate_expression(rhs, program) assert array.get_type().is_indexable(), array match array: case _: assert False, ("Unimplemented", array) case _: assert False, ("Unimplemented", lhs) case IfStatement(condition, body, else_body): if is_truthy(calculate_expression(condition, program)): program.push_context({}) return_value = interpret_statements([body], program) program.pop_context() if not isinstance(return_value, NothingResult): return return_value elif else_body: program.push_context({}) return_value = interpret_statements([else_body], program) program.pop_context() if not isinstance(return_value, NothingResult): return return_value case Statements(statements): # TODO: Proper context and scoping program.push_context({}) return_value = interpret_statements(statements, program) program.pop_context() if not isinstance(return_value, NothingResult): return return_value case FunctionDefinition(name, arguments_, return_type_, body): def run_function(name: str, arguments: List_[Tuple[str, Type]], return_type: Type, body: Statement, *args: Object) -> Object: assert len(args) == len(arguments), f"'{name}' expected {len(arguments)} arguments, but got {len(args)} instead!" new_program = ProgramState(program.modules, program.contexts[:2]) new_program.push_context({}) for (argument, (argument_name, argument_type)) in zip(args, arguments): assert argument.get_type().is_subtype_of(argument_type), f"'{name}' expected argument '{argument_name}' to have a value of type {argument_type.represent()}, but got {argument.get_type().represent()} instead!" new_program.declare_variable(argument_name, argument_type) new_program.assign_variable(argument_name, argument) return_value = interpret_statements([body], new_program) new_program.pop_context() assert len(new_program.contexts) == 2 match return_value: case ReturnResult(value): assert value.get_type().is_subtype_of(return_type), f"'{name}' expected a return value of type {return_type.represent()}, but got {value.get_type().represent()}!" return value case NothingResult(): assert return_type.is_subtype_of(VoidType), f"'{name}' expected a return type of {return_type.represent()} but got nothing!" return Void case _: assert False, ("Unimplemented", return_value) arguments = [(argument.name, calculate_type_expression(argument.type, program)) for argument in arguments_] return_type = calculate_type_expression(return_type_, program) if return_type_ else VoidType function_type = FunctionType([argument[1] for argument in arguments], return_type) object = Function(function_type, (name, arguments, return_type, body, run_function)) program.declare_and_assign_variable(name, object) case EnumDefinition(name, entries): enum_type = EnumType(name, {entry.name: [calculate_type_expression(type, program, False) for type in entry.types] for entry in entries}, []) program.declare_and_assign_variable(name, TypeObject(enum_type)) update_types(enum_type, program) case StructDefinition(name, entries): struct_type = StructType(name, {entry.name: calculate_type_expression(entry.type, program, False) for entry in entries}, []) program.declare_and_assign_variable(name, TypeObject(struct_type)) update_types(struct_type, program) case MatchStatement(value_, cases): value = calculate_expression(value_, program) assert isinstance(value, EnumValue), f"Cannot only match over enums, got {value.get_type().represent()} instead!" assert isinstance(value.type, EnumType) # TODO: Pattern match things besides enums for case in cases: if (new_variables := match_enum_expression(value.type, value, case[0])) is not None: program.push_context({name: Declared.from_obj(new_variables[name]) for name in new_variables}) return_value = interpret_statements([case[1]], program) program.pop_context() if not isinstance(return_value, NothingResult): return return_value break case DoWhileStatement(body, condition_): assert condition_ is None # TODO program.push_context({}) return_value = interpret_statements([body], program) program.pop_context() if not isinstance(return_value, NothingResult): return return_value case WhileStatement(condition_, body): while is_truthy(calculate_expression(condition_, program)): program.push_context({}) return_value = interpret_statements([body], program) program.pop_context() match return_value: case NothingResult(): pass case ContinueResult(): continue case BreakResult(): break case ReturnResult(_): return return_value case _: assert False, ("Unimplemented", return_value) if not isinstance(return_value, NothingResult): return return_value case AssertStatement(condition_, message_): if not is_truthy(calculate_expression(condition_, program)): if message_: message = calculate_expression(message_, program) assert isinstance(message, Str) assert False, message.str assert False, "Assertion failed" case TypeDeclarationStatement(declaration): program.declare_variable(declaration.name, calculate_type_expression(declaration.type, program)) case ForLoop(variable, array_, body): array = calculate_expression(array_, program) if isinstance(array, ListObject): for value in array.list: assert isinstance(value, Object) assert value.get_type().is_subtype_of(array.type.type) program.push_context({variable: Declared.from_obj(value)}) return_value = interpret_statements([body], program) program.pop_context() match return_value: case NothingResult(): pass case ReturnResult(_): return return_value case _: assert False, ("Unimplemented", return_value) case ContinueStatement(): return ContinueResult() case BreakStatement(): return BreakResult() case ReturnStatement(expression=expression): return ReturnResult(calculate_expression(expression, program)) case Import(file): # TODO: Maybe an inclusion system within a preprocessor maybe module = interpret_file(file, program.modules) if file not in program.modules else program.modules[file] program.contexts[0] |= module if file not in program.modules: program.modules[file] = module case TypeDefinition(name, expression_): program.declare_and_assign_variable(name, TypeObject(calculate_type_expression(expression_, program))) case DeferStatement(statement=statement): assert False, "TODO: Defers are not implemented" case _: assert False, ("Unimplemented", statement) return NothingResult() def interpret_file(file_path: str, modules: Dict[str, Module]) -> Module: # print(f"\tParsing {file_path}") lexer = Lexer.from_file(file_path) statements: List_[Statement] = [] while not lexer.check_token(EofToken()): statements.append(parse_statement(lexer)) # print(f"\tInterpreting {file_path}") program = ProgramState(modules, [{variable: Declared.from_obj(variables[variable]) for variable in variables}, {}]) return_value = interpret_statements(statements, program) # print(f"Finished {file_path}") assert len(program.contexts) == 2 match return_value: case NothingResult(): pass case ReturnResult(_): assert False, "Cannot return from outside a function!" case ContinueResult(): assert False, "Cannot continue from outside a loop!" case BreakResult(): assert False, "Cannot break from outside a loop!" case _: assert False, ("Unimplemented", return_value) return program.contexts[1]