python-plus-plus/ppp_interpreter.py

572 lines
26 KiB
Python

from dataclasses import dataclass
from typing import Dict, List as List_, Optional, Tuple
from ppp_ast import *
from ppp_lexer import Lexer
from ppp_object import Bool, EnumValue, Function, Int, Object, Str, Struct, Tuple as TupleObject, List as ListObject, TypeObject, Void
from ppp_parser import is_valid_target, parse_statement
from ppp_tokens import EofToken
from ppp_stdlib import variables
from ppp_types import EnumType, FunctionType, GenericType, Int as IntType, ListType, Str as StrType, StructType, TupleType, Type, TypeType, VariableType, Void as VoidType
@dataclass
class Declared:
    """Variable state: a name that has a declared type and currently holds a value."""
    type: Type  # declared type; assign_variable requires new values to be subtypes of it
    value: Object  # current value

    @staticmethod
    def from_obj(obj: Object) -> 'Declared':
        """Wrap ``obj``, using its own runtime type as the declared type."""
        return Declared(obj.get_type(), obj)

@dataclass
class Undeclared:
    """Variable state: the name has a declared type but no value has been assigned yet."""
    type: Type

@dataclass
class Constant:
    """Variable state holding a type and a value, like ``Declared``.

    NOTE(review): nothing in this file prevents reassignment -- ``ProgramState.assign_variable``
    replaces a Constant with a ``Declared``. Confirm whether constness should be enforced.
    """
    type: Type
    value: Object

    @staticmethod
    def from_obj(obj: Object) -> 'Constant':
        """Wrap ``obj`` as a Constant, using its own runtime type as the declared type."""
        # Previously this returned a ``Declared``, so it could never actually create a Constant.
        return Constant(obj.get_type(), obj)

# Everything a scope dictionary can map a name to.
VariableState = Declared | Undeclared | Constant
# A module is the top-level scope of an interpreted file (see interpret_file).
Module = Dict[str, VariableState]
@dataclass
class ProgramState:
    """Interpreter state: the shared module cache plus a stack of variable scopes.

    ``contexts[0]`` is the outermost scope and ``contexts[-1]`` the innermost.
    Name lookups search from the innermost scope outwards.
    """
    modules: Dict[str, Module] # TODO: What is the type of module?
    contexts: List_[Dict[str, VariableState]]

    def push_context(self, variables: Dict[str, VariableState]):
        """Enter a new innermost scope. ``variables`` is used directly, not copied."""
        self.contexts.append(variables)

    def pop_context(self):
        """Leave the innermost scope."""
        self.contexts.pop()

    def declare_variable(self, name: str, type: Type):
        """Declare ``name`` with ``type`` in the innermost scope, without a value.

        Only the innermost scope is checked for a clash, so outer names may be shadowed.
        """
        innermost = self.contexts[-1]
        assert name not in innermost, f"'{name}' has already been declared!"
        innermost[name] = Undeclared(type)

    def assign_variable(self, name: str, value: Object):
        """Store ``value`` in the innermost scope that knows ``name``, keeping its declared type.

        Fails if no scope has ``name``, or if ``value``'s type is not a subtype of
        the declared type.
        """
        for scope in reversed(self.contexts):
            if name not in scope:
                continue
            declared_type = scope[name].type
            assert value.get_type().is_subtype_of(declared_type), f"In the assignment of '{name}', expected value of type {declared_type.represent()}, but got a value of type {value.get_type().represent()}!"
            scope[name] = Declared(declared_type, value)
            return
        assert False, f"'{name}' doesn't exist!"

    def declare_and_assign_variable(self, name: str, value: Object):
        """Declare ``name`` in the innermost scope with ``value``'s own type, then assign it."""
        value_type = value.get_type()
        self.declare_variable(name, value_type)
        self.assign_variable(name, value)

    def exists(self, name: str) -> bool:
        """Return whether any scope contains ``name``, whether or not it has a value."""
        return any(name in scope for scope in self.contexts)

    def access_variable(self, name: str) -> Object:
        """Return the value bound to ``name`` in the innermost scope that has it.

        Fails if the name is unknown or has been declared without a value.
        """
        for scope in reversed(self.contexts):
            if name not in scope:
                continue
            state = scope[name]
            assert not isinstance(state, Undeclared), f"{name} is not declared!"
            return state.value
        assert False, f"'{name}' is not defined!"
def is_truthy(object: Object) -> bool:
match object:
case Bool(value): return value
case _: assert False, ("Unimplemented", object)
def calculate_expression(expression: Expression, program: ProgramState) -> Object:
match expression:
case FunctionCall(function_, arguments_):
function = calculate_expression(function_, program)
assert isinstance(function, Function), (function_, function)
name, parameters, return_type, body, func = function.function
arguments = [calculate_expression(argument, program) for argument in arguments_]
assert len(arguments) == len(parameters), f"{name} expected {len(parameters)} arguments, but got {len(arguments)}!"
for (argument, (parameter_name, parameter)) in zip(arguments, parameters):
assert argument.get_type().is_subtype_of(parameter), f"For argument '{parameter_name}' of '{name}', expected value of type {parameter.represent()}, but got {argument.get_type().represent()}!"
return_value = func(name, parameters, return_type, body, *arguments)
assert isinstance(return_value, Object), return_value
assert return_value.get_type().is_subtype_of(return_type)
return return_value
case Variable(name):
return program.access_variable(name)
case String(string):
return Str(string)
case Number(number):
return Int(number)
case TupleExpr(elements_):
tuple_elements = [calculate_expression(element, program) for element in elements_]
return TupleObject(TupleType([element.get_type() for element in tuple_elements]), tuple(tuple_elements))
case Ternary(condition_, if_true, if_false):
return calculate_expression(if_true, program) if is_truthy(calculate_expression(condition_, program)) else calculate_expression(if_false, program)
case Or(lhs, rhs):
left_value = calculate_expression(lhs, program)
assert isinstance(left_value, Bool)
if left_value.value: return Bool(True)
right_value = calculate_expression(rhs, program)
assert isinstance(right_value, Bool)
return Bool(left_value.value or right_value.value)
case And(lhs, rhs):
left_value = calculate_expression(lhs, program)
assert isinstance(left_value, Bool)
if not left_value.value: return Bool(False)
right_value = calculate_expression(rhs, program)
assert isinstance(right_value, Bool)
return Bool(left_value.value and right_value.value)
case Bor(lhs, rhs):
assert False, ("Unimplemented", lhs, rhs)
case Bxor(lhs, rhs):
assert False, ("Unimplemented", lhs, rhs)
case Band(lhs, rhs):
assert False, ("Unimplemented", lhs, rhs)
case Equal(lhs, rhs):
left_value = calculate_expression(lhs, program)
right_value = calculate_expression(rhs, program)
if left_value.get_type() != right_value.get_type(): return Bool(False)
return Bool(left_value == right_value)
case NotEqual(lhs, rhs):
left_value = calculate_expression(lhs, program)
right_value = calculate_expression(rhs, program)
if left_value.get_type() != right_value.get_type(): return Bool(True)
return Bool(left_value != right_value)
case LessThan(lhs, rhs):
left_value = calculate_expression(lhs, program)
right_value = calculate_expression(rhs, program)
assert isinstance(left_value, Int)
assert isinstance(right_value, Int)
return Bool(left_value.num < right_value.num)
case GreaterThan(lhs, rhs):
left_value = calculate_expression(lhs, program)
right_value = calculate_expression(rhs, program)
assert isinstance(left_value, Int)
assert isinstance(right_value, Int)
return Bool(left_value.num > right_value.num)
case LessThanOrEqual(lhs, rhs):
left_value = calculate_expression(lhs, program)
right_value = calculate_expression(rhs, program)
assert isinstance(left_value, Int)
assert isinstance(right_value, Int)
return Bool(left_value.num <= right_value.num)
case GreaterThanOrEqual(lhs, rhs):
left_value = calculate_expression(lhs, program)
right_value = calculate_expression(rhs, program)
assert isinstance(left_value, Int)
assert isinstance(right_value, Int)
return Bool(left_value.num >= right_value.num)
case ShiftLeft(lhs, rhs):
assert False, ("Unimplemented", lhs, rhs)
case ShiftRight(lhs, rhs):
assert False, ("Unimplemented", lhs, rhs)
case Addition(lhs, rhs):
left_value = calculate_expression(lhs, program)
right_value = calculate_expression(rhs, program)
if isinstance(left_value, Int):
assert isinstance(right_value, Int)
return Int(left_value.num + right_value.num)
elif isinstance(left_value, Str):
assert isinstance(right_value, Str)
return Str(left_value.str + right_value.str)
elif isinstance(left_value, ListObject):
assert isinstance(right_value, ListObject)
if left_value.type.type == VariableType(""): return right_value
if right_value.type.type == VariableType(""): return left_value
assert left_value.type == right_value.type, (left_value, right_value)
return ListObject(left_value.type, left_value.list + right_value.list)
else:
assert False, f"Expected two ints or two strs. Got {left_value.get_type().represent()} and {right_value.get_type().represent()}!"
case Subtract(lhs, rhs):
left_value = calculate_expression(lhs, program)
right_value = calculate_expression(rhs, program)
assert isinstance(left_value, Int)
assert isinstance(right_value, Int)
return Int(left_value.num - right_value.num)
case Multiplication(lhs, rhs):
left_value = calculate_expression(lhs, program)
right_value = calculate_expression(rhs, program)
assert isinstance(left_value, Int)
assert isinstance(right_value, Int)
return Int(left_value.num * right_value.num)
case Division(lhs, rhs):
assert False, ("Unimplemented", lhs, rhs)
case Modulo(lhs, rhs):
left_value = calculate_expression(lhs, program)
right_value = calculate_expression(rhs, program)
if isinstance(left_value, Int):
assert isinstance(right_value, Int), f"Expected int, got {right_value.get_type().represent()}!"
return Int(left_value.num % right_value.num)
elif isinstance(left_value, Str):
# TODO: Maybe actually just implement C-style string formatting? This code is a mess
match right_value:
case TupleObject(_, tuple_obj_elements):
assert left_value.str.count("%"+"s") == len(tuple_obj_elements), (left_value.str.count("%%s"), len(tuple_obj_elements))
the_elements: Tuple[str, ...] = ()
for element in tuple_obj_elements:
assert isinstance(element, Str)
the_elements += (element.str,)
return Str(left_value.str % the_elements)
case Str(string):
return Str(left_value.str % string)
case _:
assert False, f"Format string expected either a string or a tuple of strings, but got a '{right_value.get_type().represent()}'!\n{lhs, rhs, right_value}"
match right_value:
case Int(num): return Str(left_value.str % num)
case _: assert False, ("Unimplemented", right_value)
assert False, ("Unimplemented", lhs, rhs)
case StructInstantiation(struct_, arguments_):
struct = calculate_expression(struct_, program)
assert isinstance(struct, TypeObject)
assert isinstance(struct.type, StructType)
struct_arguments = {name: calculate_expression(expression, program) for (name, expression) in arguments_}
for field in struct_arguments:
assert field in struct.type.members, f"The struct {struct.type.name} does not have the field '{field}'!"
assert struct_arguments[field].get_type().is_subtype_of(struct.type.members[field]), f"'{struct.type.name}.{field}' field expected value of type {struct.type.members[field].represent()}, but got a value of type {struct_arguments[field].get_type().represent()}!"
for field in struct.type.members:
assert field in struct_arguments, f"Missing field '{field}' of type {struct.type.represent()}."
return Struct(struct.type, struct_arguments)
case FieldAccess(expression_, field):
value = calculate_expression(expression_, program)
match value:
case TypeObject(type):
match type:
case EnumType(name, members, _):
assert field in members, f"{type.represent()} does not contain the member '{field}'!"
member = members[field]
if not member: return EnumValue(type, field, [])
def return_member(name: str, parameters: List_[Tuple[str, Type]], return_type: Type, _statement: Statement, *args: Object):
assert isinstance(return_type, EnumType)
assert len(args) == len(parameters)
for (arg, (_, parameter)) in zip(args, parameters):
assert arg.get_type().is_subtype_of(parameter)
return EnumValue(return_type, name, list(args))
return Function(FunctionType(member, type), (field, [('', member_type) for member_type in member], type, Statements([]), return_member))
case _: assert False, ("Unimplemented", type, field)
case Struct(type, fields):
assert field in fields, f"Struct '{type.represent()}' does not have the field '{field}'!"
return fields[field]
case _: assert False, ("Unimplemented", value, field)
case ArrayAccess(array_, index_):
array = calculate_expression(array_, program)
assert array.get_type().is_indexable(), f"Objects of type {array.get_type().represent()} cannot be indexed!"
if isinstance(array, Str):
index = calculate_expression(index_, program)
assert isinstance(index, Int), f"Index must be '{IntType.represent()}', got '{index.get_type().represent()}'!"
assert 0 <= index.num < len(array.str), f"Index out of bounds. Str of length {len(array.str)} accessed at index {index.num}. {array_, index_}"
return Str(array.str[index.num])
elif isinstance(array, TupleObject):
index = calculate_expression(index_, program)
assert isinstance(index, Int), f"Index must be '{IntType.represent()}', got '{index.get_type().represent()}'!"
assert 0 <= index.num <= len(array.tuple), f"Index out of bounds. Tuple of length {len(array.tuple)} accessed at index {index.num}. {array_, index_}"
return array.tuple[index.num]
elif isinstance(array, ListObject):
array_type = array.type.type
index = calculate_expression(index_, program)
assert isinstance(index, Int), f"Index must be '{IntType.represent()}', got '{index.get_type().represent()}'!"
assert 0 <= index.num < len(array.list), f"Index out of bounds. List of length {len(array.list)} accessed at index {index.num}"
element = array.list[index.num]
assert element.get_type().is_subtype_of(array_type)
return element
else:
assert False, "Unreachable"
case Bnot(expression_):
assert False, ("Unimplemented", expression_)
case Not(expression_):
value = calculate_expression(expression_, program)
assert isinstance(value, Bool)
return Bool(not value.value)
case UnaryPlus(expression_):
assert False, ("Unimplemented", expression_)
case UnaryMinus (expression_):
assert False, ("Unimplemented", expression_)
case Array(element_type_, array_):
element_type = calculate_type_expression(element_type_, program)
array_elements_: List_[Object] = []
for element_ in array_:
element = calculate_expression(element_, program)
assert element.get_type().is_subtype_of(element_type), (element, element_type)
array_elements_.append(element)
return ListObject(ListType(element_type), array_elements_)
case LoopComprehension(element_type_, body_, variable, array_):
element_type = calculate_type_expression(element_type_, program)
array = calculate_expression(array_, program)
assert array.get_type().is_indexable()
if isinstance(array, ListObject):
elements: List_[Object] = []
for element in array.list:
program.push_context({variable: Declared.from_obj(element)})
elements.append(calculate_expression(body_, program))
program.pop_context()
assert elements[-1].get_type().is_subtype_of(element_type)
return ListObject(ListType(element_type), elements)
else:
assert False, ("Unimplemented", array)
case _:
assert False, ("Unimplemented", expression)
assert False
def calculate_type_expression(expression: TypeExpression, program: ProgramState, must_resolve:bool=True) -> Type:
match expression:
case TypeName(name):
if not program.exists(name) and not must_resolve: return VariableType(name)
type_obj = program.access_variable(name)
assert isinstance(type_obj, TypeObject)
return type_obj.type
case ListTypeExpr(type_):
return ListType(calculate_type_expression(type_, program, must_resolve))
case TupleTypeExpr(types_):
return TupleType([calculate_type_expression(type, program, must_resolve) for type in types_])
case FunctionTypeExpr(arguments_, return_type_):
return FunctionType([calculate_type_expression(argument, program, must_resolve) for argument in arguments_], calculate_type_expression(return_type_, program, must_resolve))
case TypeSpecification(type_, types_):
type = calculate_type_expression(type_, program, must_resolve)
assert isinstance(type, GenericType)
assert len(type.variables) == len(types_)
types = [calculate_type_expression(type_, program, must_resolve) for type_ in types_]
result_type = type.substitute(types)
return result_type
case _:
assert False, ("Unimplemented", expression)
assert False, "Unreachable"
def match_enum_expression(enum: Type, value: Object, expression: Expression) -> Optional[Dict[str, Object]]:
assert isinstance(enum, EnumType)
assert isinstance(value, EnumValue)
match expression:
case Variable(name):
if name.startswith('_'): return {name: value}
assert name in enum.members, f"Enum '{enum.represent()}' does not contain the member '{name}'!"
assert enum.members[name] == [], f"Enum member '{enum.represent()}.{name}' has {len(enum.members[name])} fields that have not been captured!"
if name != value.name or value.values: return None
return {}
case FunctionCall(function, arguments):
assert isinstance(function, Variable)
assert function.name in enum.members, f"Enum '{enum.represent()}' does not contain the member '{function.name}'!"
if function.name != value.name: return None
member = enum.members[function.name]
assert isinstance(member, list) # TODO: Report calling a struct enum member with parentheses
assert isinstance(value.values, list) # Same as above but like inverse
assert len(arguments) == len(member), f"{value.get_type().represent()}.{value.name} expected {len(member)} args, but got {len(arguments)}!"
assert len(member) == len(value.values)
new_variables: Dict[str, Object] = {}
for argument, element in zip(arguments, value.values):
assert isinstance(argument, Variable) # TODO
new_variables[argument.name] = element
return new_variables
case _:
assert False, ("Unimplemented", expression)
assert False, ("Unimplemented", value, expression)
def update_types(type: Type, program: ProgramState):
assert isinstance(type, EnumType) or isinstance(type, StructType)
for context in program.contexts:
for variable_ in context:
variable = context[variable_]
match variable:
case Declared(type_, value):
if isinstance(variable.value, TypeObject):
assert type_ == TypeType
assert isinstance(value, TypeObject)
value.type.fill({type.name: type}, [])
case Undeclared(type_): pass
case _: assert False, ("Unimplemented", variable)
@dataclass
class ReturnResult:
    """A ``return`` statement ran. ``value`` is the object being returned."""
    value: Object

@dataclass
class ContinueResult:
    """A ``continue`` statement ran. Signals the enclosing loop to go to its next iteration."""

@dataclass
class BreakResult:
    """A ``break`` statement ran. Signals the enclosing loop to stop."""

@dataclass
class NothingResult:
    """Every statement completed normally, with no control-flow transfer."""

# The possible outcomes of running a list of statements.
StatementsResult = ReturnResult | ContinueResult | BreakResult | NothingResult
def interpret_statements(statements: List_[Statement], program: ProgramState) -> StatementsResult:
for statement in statements:
match statement:
case ExpressionStatement(expression):
calculate_expression(expression, program)
case Assignment(lhs, rhs, type_):
assert is_valid_target(lhs)
match lhs:
case Variable(name):
value = calculate_expression(rhs, program)
if type_:
type = calculate_type_expression(type_, program)
program.declare_variable(name, type)
program.assign_variable(name, value)
case FieldAccess(expression_, field):
expr = calculate_expression(expression_, program)
assert isinstance(expr, Struct)
struct_type = expr.get_type()
assert isinstance(struct_type, StructType)
assert field in struct_type.members, f"Struct '{struct_type.represent()}' does not contain the field '{field}'!"
value = calculate_expression(rhs, program)
assert value.get_type().is_subtype_of(struct_type.members[field])
expr.fields[field] = value
case ArrayAccess(array_, index_):
array = calculate_expression(array_, program)
index = calculate_expression(index_, program)
value = calculate_expression(rhs, program)
assert array.get_type().is_indexable(), array
match array:
case _: assert False, ("Unimplemented", array)
case _:
assert False, ("Unimplemented", lhs)
case IfStatement(condition, body, else_body):
if is_truthy(calculate_expression(condition, program)):
program.push_context({})
return_value = interpret_statements([body], program)
program.pop_context()
if not isinstance(return_value, NothingResult): return return_value
elif else_body:
program.push_context({})
return_value = interpret_statements([else_body], program)
program.pop_context()
if not isinstance(return_value, NothingResult): return return_value
case Statements(statements):
# TODO: Proper context and scoping
program.push_context({})
return_value = interpret_statements(statements, program)
program.pop_context()
if not isinstance(return_value, NothingResult): return return_value
case FunctionDefinition(name, arguments_, return_type_, body):
def run_function(name: str, arguments: List_[Tuple[str, Type]], return_type: Type, body: Statement, *args: Object) -> Object:
assert len(args) == len(arguments), f"'{name}' expected {len(arguments)} arguments, but got {len(args)} instead!"
new_program = ProgramState(program.modules, program.contexts[:2])
new_program.push_context({})
for (argument, (argument_name, argument_type)) in zip(args, arguments):
assert argument.get_type().is_subtype_of(argument_type), f"'{name}' expected argument '{argument_name}' to have a value of type {argument_type.represent()}, but got {argument.get_type().represent()} instead!"
new_program.declare_variable(argument_name, argument_type)
new_program.assign_variable(argument_name, argument)
return_value = interpret_statements([body], new_program)
new_program.pop_context()
assert len(new_program.contexts) == 2
match return_value:
case ReturnResult(value):
assert value.get_type().is_subtype_of(return_type), f"'{name}' expected a return value of type {return_type.represent()}, but got {value.get_type().represent()}!"
return value
case NothingResult():
assert return_type.is_subtype_of(VoidType), f"'{name}' expected a return type of {return_type.represent()} but got nothing!"
return Void
case _: assert False, ("Unimplemented", return_value)
arguments = [(argument.name, calculate_type_expression(argument.type, program)) for argument in arguments_]
return_type = calculate_type_expression(return_type_, program) if return_type_ else VoidType
function_type = FunctionType([argument[1] for argument in arguments], return_type)
object = Function(function_type, (name, arguments, return_type, body, run_function))
program.declare_and_assign_variable(name, object)
case EnumDefinition(name, entries):
enum_type = EnumType(name, {entry.name: [calculate_type_expression(type, program, False) for type in entry.types] for entry in entries}, [])
program.declare_and_assign_variable(name, TypeObject(enum_type))
update_types(enum_type, program)
case StructDefinition(name, entries):
struct_type = StructType(name, {entry.name: calculate_type_expression(entry.type, program, False) for entry in entries}, [])
program.declare_and_assign_variable(name, TypeObject(struct_type))
update_types(struct_type, program)
case MatchStatement(value_, cases):
value = calculate_expression(value_, program)
assert isinstance(value, EnumValue), f"Cannot only match over enums, got {value.get_type().represent()} instead!"
assert isinstance(value.type, EnumType) # TODO: Pattern match things besides enums
for case in cases:
if (new_variables := match_enum_expression(value.type, value, case[0])) is not None:
program.push_context({name: Declared.from_obj(new_variables[name]) for name in new_variables})
return_value = interpret_statements([case[1]], program)
program.pop_context()
if not isinstance(return_value, NothingResult): return return_value
break
case DoWhileStatement(body, condition_):
assert condition_ is None # TODO
program.push_context({})
return_value = interpret_statements([body], program)
program.pop_context()
if not isinstance(return_value, NothingResult): return return_value
case WhileStatement(condition_, body):
while is_truthy(calculate_expression(condition_, program)):
program.push_context({})
return_value = interpret_statements([body], program)
program.pop_context()
match return_value:
case NothingResult(): pass
case ContinueResult(): continue
case BreakResult(): break
case ReturnResult(_): return return_value
case _: assert False, ("Unimplemented", return_value)
if not isinstance(return_value, NothingResult): return return_value
case AssertStatement(condition_, message_):
if not is_truthy(calculate_expression(condition_, program)):
if message_:
message = calculate_expression(message_, program)
assert isinstance(message, Str)
assert False, message.str
assert False, "Assertion failed"
case TypeDeclarationStatement(declaration):
program.declare_variable(declaration.name, calculate_type_expression(declaration.type, program))
case ForLoop(variable, array_, body):
array = calculate_expression(array_, program)
if isinstance(array, ListObject):
for value in array.list:
assert isinstance(value, Object)
assert value.get_type().is_subtype_of(array.type.type)
program.push_context({variable: Declared.from_obj(value)})
return_value = interpret_statements([body], program)
program.pop_context()
match return_value:
case NothingResult(): pass
case ReturnResult(_): return return_value
case _: assert False, ("Unimplemented", return_value)
case ContinueStatement(): return ContinueResult()
case BreakStatement(): return BreakResult()
case ReturnStatement(expression=expression):
return ReturnResult(calculate_expression(expression, program))
case Import(file):
# TODO: Maybe an inclusion system within a preprocessor maybe
module = interpret_file(file, program.modules) if file not in program.modules else program.modules[file]
program.contexts[0] |= module
if file not in program.modules:
program.modules[file] = module
case TypeDefinition(name, expression_):
program.declare_and_assign_variable(name, TypeObject(calculate_type_expression(expression_, program)))
case DeferStatement(statement=statement):
assert False, "TODO: Defers are not implemented"
case _:
assert False, ("Unimplemented", statement)
return NothingResult()
def interpret_file(file_path: str, modules: Dict[str, Module]) -> Module:
# print(f"\tParsing {file_path}")
lexer = Lexer.from_file(file_path)
statements: List_[Statement] = []
while not lexer.check_token(EofToken()): statements.append(parse_statement(lexer))
# print(f"\tInterpreting {file_path}")
program = ProgramState(modules, [{variable: Declared.from_obj(variables[variable]) for variable in variables}, {}])
return_value = interpret_statements(statements, program)
# print(f"Finished {file_path}")
assert len(program.contexts) == 2
match return_value:
case NothingResult(): pass
case ReturnResult(_): assert False, "Cannot return from outside a function!"
case ContinueResult(): assert False, "Cannot continue from outside a loop!"
case BreakResult(): assert False, "Cannot break from outside a loop!"
case _: assert False, ("Unimplemented", return_value)
return program.contexts[1]