from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Dict, List, Tuple import sys sys.setrecursionlimit(1000) class Type(ABC): def is_indexable(self) -> bool: return False def is_subtype_of(self, other: 'Type') -> bool: match self, other: case StrType(), StrType(): return True case TypeType_(), TypeType_(): return True case FunctionType(self_arguments, self_return_type), FunctionType(other_arguments, other_return_type): assert len(self_arguments) == len(other_arguments) for (self_argument, other_argument) in zip(self_arguments, other_arguments): if not other_argument.is_subtype_of(self_argument): return False return self_return_type.is_subtype_of(other_return_type) case EnumType(self_name, self_members), EnumType(other_name, other_members): # if self_name == other_name: assert self is other, (num_expressions, self, other, self_name, other_name, self_members, other_members) return self is other return self_name == other_name case StructType(_, _), StructType(_, _): return self is other case VoidType(), VoidType(): return True case ListType(self_type), ListType(other_type): # TODO: Maybe return which types match if isinstance(self_type, VariableType): return True if isinstance(other_type, VariableType): return True return self_type == other_type return self_type.is_subtype_of(other_type) case TupleType(self_elememts), TupleType(other_elements): if len(self_elememts) != len(other_elements): return False for (self_element, other_element) in zip(self_elememts, other_elements): if not self_element.is_subtype_of(other_element): return False return True case IntType(), IntType(): return True case VariableType(self_name), VariableType(other_name): return self_name == other_name case _, VariableType(""): return True case BoolType(), BoolType(): return True case type, ObjectType(): return True case type_a, type_b if type_a.__class__ != type_b.__class__: return False case _, _: assert False, ("Unimplemented", self, other) assert False, ("Unimplemented", self, other) def __eq__(self, other): return isinstance(other, Type) and self.is_subtype_of(other) and other.is_subtype_of(self) @abstractmethod def represent(self) -> str: ... @abstractmethod def fill(self, types: 'Dict[str, Type]', stack: List[int]) -> 'Type': ... @abstractmethod def new_fill(self, types: 'Dict[str, Type]', stack: List[int]) -> 'Tuple[bool, Type]': ... def new_fill_list(self, type_list: 'List[Type]', types: 'Dict[str, Type]', stack: List[int]) -> 'Tuple[bool, List[Type]]': new_types = [type.new_fill(types, stack+[id(self)]) for type in type_list] is_new = any([new_type[0] for new_type in new_types]) return (is_new, [new_type[1] for new_type in new_types]) def new_fill_dict(self, type_dict: 'Dict[str, Type]', types: 'Dict[str, Type]', stack: List[int]) -> 'Tuple[bool, Dict[str, Type]]': new_types = {field: type_dict[field].new_fill(types, stack+[id(self)]) for field in type_dict} is_new = any([new_types[field][0] for field in new_types]) return (is_new, {field: new_types[field][1] for field in new_types}) class Primitive(Type): def fill(self, types: Dict[str, Type], stack: List[int]) -> Type: return self def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]: return (False, self) class IntType(Primitive): def represent(self) -> str: return 'int' Int = IntType() class StrType(Primitive): def is_indexable(self) -> bool: return True def represent(self) -> str: return 'str' Str = StrType() class BoolType(Primitive): def represent(self) -> str: return 'bool' Bool = BoolType() class VoidType(Primitive): def represent(self) -> str: return 'void' Void = VoidType() class TypeType_(Primitive): def represent(self) -> str: return 'type' TypeType = TypeType_() @dataclass class TupleType(Type): types: List[Type] def is_indexable(self) -> bool: return True def represent(self) -> str: return '('+', '.join([type.represent() for type in self.types])+')' def fill(self, types: Dict[str, Type], stack: List[int]) -> Type: if id(self) in stack: return self self.types = [type.fill(types, stack+[id(self)]) for type in self.types] return self def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]: is_new, new_types = self.new_fill_list(self.types, types, stack) return (is_new, TupleType(new_types)) @dataclass class ListType(Type): type: Type def is_indexable(self) -> bool: return True def represent(self) -> str: return self.type.represent()+'[]' def fill(self, types: Dict[str, Type], stack: List[int]) -> Type: if id(self) in stack: return self self.type = self.type.fill(types, stack+[id(self)]) return self def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]: assert id(self) not in stack is_new, new_type = self.type.new_fill(types, stack+[id(self)]) return (is_new, ListType(new_type)) @dataclass class ArrayType(Type): type: Type number: int def is_indexable(self) -> bool: return True @dataclass class FunctionType(Type): arguments: List[Type] return_type: Type def represent(self) -> str: return '('+', '.join([type.represent() for type in self.arguments])+') -> '+self.return_type.represent() def fill(self, types: Dict[str, Type], stack: List[int]) -> Type: if id(self) in stack: return self self.arguments = [argument.fill(types, stack+[id(self)]) for argument in self.arguments] self.return_type = self.return_type.fill(types, stack+[id(self)]) return self def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]: assert id(self) not in stack # TODO: Wtf? is_new_arguments, new_arguments = self.new_fill_list(self.arguments, types, stack) is_new_return_type, new_return_type = self.return_type.new_fill(types, stack+[id(self)]) return (is_new_arguments or is_new_return_type, FunctionType(new_arguments, new_return_type)) class ObjectType(Primitive): def represent(self) -> str: return 'object' Object = ObjectType() num_expressions: int = 0 @dataclass class EnumType(Type): name: str members: Dict[str, List[Type]] generics: List[Type] def __repr__(self) -> str: return self.represent() def represent(self) -> str: return self.name+('['+', '.join([generic.represent() for generic in self.generics])+']' if self.generics else '') def fill(self, types: Dict[str, Type], stack: List[int]) -> Type: if id(self) in stack: return self self.members = {member_name: [element.fill(types, stack+[id(self)]) for element in self.members[member_name]] for member_name in self.members} self.generics = [type.fill(types, stack+[id(self)]) for type in self.generics] return self def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]: assert id(self) not in stack is_new = False new_members: Dict[str, List[Type]] = {} for member_name in self.members: member = self.members[member_name] is_new_member, new_members[member_name] = self.new_fill_list(member, types, stack) is_new = is_new or is_new_member return (is_new, EnumType(self.name, new_members, self.generics) if is_new else self) @dataclass class StructType(Type): name: str members: Dict[str, Type] generics: List[Type] def represent(self) -> str: assert not self.generics return self.name def fill(self, types: Dict[str, Type], stack: List[int]) -> Type: if id(self) in stack: return self for field in self.members: self.members[field] = self.members[field].fill(types, stack+[id(self)]) return self def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]: assert id(self) not in stack is_new, new_members = self.new_fill_dict(self.members, types, stack) return (is_new, StructType(self.name, new_members, self.generics) if is_new else self) @dataclass class VariableType(Type): name: str def represent(self) -> str: return self.name + '?' def fill(self, types: Dict[str, Type], stack: List[int]) -> Type: return types.get(self.name, self) def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]: return (self.name in types, types.get(self.name, self)) @dataclass class GenericType(Type): variables: List[VariableType] type: Type def represent(self) -> str: assert False def fill(self, types: Dict[str, Type], stack: List[int]) -> Type: if id(self) in stack: return self self.type = self.type.fill(types, stack+[id(self)]) return self def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]: assert False def substitute(self, types: List[Type]) -> Type: assert len(types) == len(self.variables), f"{self.type.represent()} expected {len(self.variables)} type parameters, but got {len(types)}!" return self.type.new_fill({variable.name: type for (variable, type) in zip(self.variables, types)}, [])[1]