2024-08-08 21:54:03 +10:00
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from dataclasses import dataclass
|
2024-08-13 12:45:42 +10:00
|
|
|
from typing import Dict, List, Tuple
|
2024-08-08 21:54:03 +10:00
|
|
|
|
|
|
|
import sys
|
|
|
|
# NOTE(review): 1000 is CPython's default recursion limit, so this call only
# pins the limit at its usual value. The fill/new_fill/is_subtype_of methods
# below recurse once per level of type nesting; confirm whether a higher limit
# was actually intended here.
sys.setrecursionlimit(1000)
|
|
|
|
|
|
|
|
class Type(ABC):
|
|
|
|
def is_indexable(self) -> bool: return False
|
|
|
|
|
|
|
|
def is_subtype_of(self, other: 'Type') -> bool:
|
|
|
|
match self, other:
|
|
|
|
case StrType(), StrType(): return True
|
|
|
|
case TypeType_(), TypeType_(): return True
|
|
|
|
case FunctionType(self_arguments, self_return_type), FunctionType(other_arguments, other_return_type):
|
|
|
|
assert len(self_arguments) == len(other_arguments)
|
|
|
|
for (self_argument, other_argument) in zip(self_arguments, other_arguments):
|
|
|
|
if not other_argument.is_subtype_of(self_argument): return False
|
|
|
|
return self_return_type.is_subtype_of(other_return_type)
|
2024-08-12 00:00:03 +10:00
|
|
|
case EnumType(self_name, self_members), EnumType(other_name, other_members):
|
2024-08-08 21:54:03 +10:00
|
|
|
# if self_name == other_name: assert self is other, (num_expressions, self, other, self_name, other_name, self_members, other_members)
|
|
|
|
return self is other
|
|
|
|
return self_name == other_name
|
|
|
|
case StructType(_, _), StructType(_, _): return self is other
|
|
|
|
case VoidType(), VoidType(): return True
|
|
|
|
case ListType(self_type), ListType(other_type):
|
|
|
|
# TODO: Maybe return which types match
|
|
|
|
if isinstance(self_type, VariableType): return True
|
|
|
|
if isinstance(other_type, VariableType): return True
|
|
|
|
return self_type == other_type
|
|
|
|
return self_type.is_subtype_of(other_type)
|
|
|
|
case TupleType(self_elememts), TupleType(other_elements):
|
|
|
|
if len(self_elememts) != len(other_elements): return False
|
|
|
|
for (self_element, other_element) in zip(self_elememts, other_elements):
|
|
|
|
if not self_element.is_subtype_of(other_element): return False
|
|
|
|
return True
|
|
|
|
case IntType(), IntType():
|
|
|
|
return True
|
|
|
|
case VariableType(self_name), VariableType(other_name):
|
|
|
|
return self_name == other_name
|
|
|
|
case _, VariableType(""): return True
|
|
|
|
case BoolType(), BoolType(): return True
|
|
|
|
case type, ObjectType(): return True
|
|
|
|
case type_a, type_b if type_a.__class__ != type_b.__class__: return False
|
|
|
|
case _, _: assert False, ("Unimplemented", self, other)
|
|
|
|
assert False, ("Unimplemented", self, other)
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
return isinstance(other, Type) and self.is_subtype_of(other) and other.is_subtype_of(self)
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def represent(self) -> str: ...
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def fill(self, types: 'Dict[str, Type]', stack: List[int]) -> 'Type': ...
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def new_fill(self, types: 'Dict[str, Type]', stack: List[int]) -> 'Tuple[bool, Type]': ...
|
|
|
|
|
|
|
|
def new_fill_list(self, type_list: 'List[Type]', types: 'Dict[str, Type]', stack: List[int]) -> 'Tuple[bool, List[Type]]':
|
|
|
|
new_types = [type.new_fill(types, stack+[id(self)]) for type in type_list]
|
|
|
|
is_new = any([new_type[0] for new_type in new_types])
|
|
|
|
return (is_new, [new_type[1] for new_type in new_types])
|
2024-08-12 00:00:03 +10:00
|
|
|
|
2024-08-08 21:54:03 +10:00
|
|
|
def new_fill_dict(self, type_dict: 'Dict[str, Type]', types: 'Dict[str, Type]', stack: List[int]) -> 'Tuple[bool, Dict[str, Type]]':
|
|
|
|
new_types = {field: type_dict[field].new_fill(types, stack+[id(self)]) for field in type_dict}
|
|
|
|
is_new = any([new_types[field][0] for field in new_types])
|
|
|
|
return (is_new, {field: new_types[field][1] for field in new_types})
|
|
|
|
|
|
|
|
class Primitive(Type):
    """Base class for atomic types that contain no nested types.

    Since there is nothing inside a primitive to substitute, both filling
    operations hand back the instance itself.
    """

    def fill(self, types: Dict[str, Type], stack: List[int]) -> Type:
        return self

    def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]:
        # Report "no change" together with the untouched instance.
        unchanged = False
        return (unchanged, self)
|
|
|
|
|
|
|
|
class IntType(Primitive):
    """The integer type; ``Int`` below is a module-level instance of it."""

    def represent(self) -> str:
        name = 'int'
        return name


Int = IntType()
|
|
|
|
|
|
|
|
|
|
|
|
class StrType(Primitive):
    """The string type (indexable); ``Str`` below is a module-level instance."""

    def is_indexable(self) -> bool:
        indexable = True
        return indexable

    def represent(self) -> str:
        name = 'str'
        return name


Str = StrType()
|
|
|
|
|
|
|
|
|
|
|
|
class BoolType(Primitive):
    """The boolean type; ``Bool`` below is a module-level instance of it."""

    def represent(self) -> str:
        name = 'bool'
        return name


Bool = BoolType()
|
|
|
|
|
|
|
|
|
|
|
|
class VoidType(Primitive):
    """The empty/void type; ``Void`` below is a module-level instance of it."""

    def represent(self) -> str:
        name = 'void'
        return name


Void = VoidType()
|
|
|
|
|
|
|
|
|
|
|
|
class TypeType_(Primitive):
    """The type of types; ``TypeType`` below is a module-level instance."""

    def represent(self) -> str:
        name = 'type'
        return name


TypeType = TypeType_()
|
|
|
|
|
|
|
|
@dataclass
class TupleType(Type):
    """Fixed-length tuple type; ``types`` holds one entry per position."""

    types: List[Type]

    def is_indexable(self) -> bool:
        return True

    def represent(self) -> str:
        inner = ', '.join(element.represent() for element in self.types)
        return '(' + inner + ')'

    def fill(self, types: Dict[str, Type], stack: List[int]) -> Type:
        """Substitute variables in place; stops if ``self`` is already being filled."""
        if id(self) not in stack:
            self.types = [element.fill(types, stack + [id(self)]) for element in self.types]
        return self

    def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]:
        """Return ``(changed, tuple_type)``; always builds a fresh TupleType."""
        changed, filled = self.new_fill_list(self.types, types, stack)
        return (changed, TupleType(filled))
|
|
|
|
|
|
|
|
@dataclass
class ListType(Type):
    """Variable-length list type whose elements all have type ``type``."""

    type: Type

    def is_indexable(self) -> bool:
        return True

    def represent(self) -> str:
        element = self.type.represent()
        return element + '[]'

    def fill(self, types: Dict[str, Type], stack: List[int]) -> Type:
        """Substitute variables in place; stops if ``self`` is already being filled."""
        if id(self) not in stack:
            self.type = self.type.fill(types, stack + [id(self)])
        return self

    def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]:
        """Return ``(changed, list_type)``; always builds a fresh ListType."""
        assert id(self) not in stack
        changed, element = self.type.new_fill(types, stack + [id(self)])
        return (changed, ListType(element))
|
|
|
|
|
|
|
|
@dataclass
class ArrayType(Type):
    """Indexable type holding an element ``type`` and an integer ``number``.

    NOTE(review): this class overrides none of the abstract methods declared
    on ``Type`` (``represent``, ``fill``, ``new_fill``), so instantiating it
    raises ``TypeError``; it looks unfinished. Presumably ``number`` is a
    fixed length -- confirm before implementing the missing methods.
    """

    type: Type
    number: int

    def is_indexable(self) -> bool:
        indexable = True
        return indexable
|
|
|
|
|
|
|
|
@dataclass
class FunctionType(Type):
    """Function signature: a list of argument types plus a return type."""

    arguments: List[Type]
    return_type: Type

    def represent(self) -> str:
        params = ', '.join(argument.represent() for argument in self.arguments)
        return '(' + params + ') -> ' + self.return_type.represent()

    def fill(self, types: Dict[str, Type], stack: List[int]) -> Type:
        """Substitute variables in place; stops if ``self`` is already being filled."""
        if id(self) in stack:
            return self
        self.arguments = [argument.fill(types, stack + [id(self)]) for argument in self.arguments]
        self.return_type = self.return_type.fill(types, stack + [id(self)])
        return self

    def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]:
        """Return ``(changed, function_type)``; always builds a fresh FunctionType."""
        # TODO(review): the original author questioned why this assertion is needed.
        assert id(self) not in stack
        arguments_changed, filled_arguments = self.new_fill_list(self.arguments, types, stack)
        return_changed, filled_return = self.return_type.new_fill(types, stack + [id(self)])
        return (arguments_changed or return_changed, FunctionType(filled_arguments, filled_return))
|
|
|
|
|
|
|
|
class ObjectType(Primitive):
    """The top type; ``Object`` below is a module-level instance of it."""

    def represent(self) -> str:
        name = 'object'
        return name


Object = ObjectType()
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level integer initialised to 0. NOTE(review): the only visible
# reference is a commented-out assertion in Type.is_subtype_of, so this looks
# like a debugging aid -- confirm whether anything else still reads it.
num_expressions: int = 0
|
|
|
|
@dataclass
class EnumType(Type):
    """Named enum type whose members each carry a list of payload types.

    ``members`` maps member name -> payload types; ``generics`` are the type
    arguments printed by :meth:`represent` as ``name[arg, ...]``.
    """

    name: str
    members: Dict[str, List[Type]]
    generics: List[Type]

    # @dataclass does not replace an explicitly defined __repr__, so this short
    # form is used instead of the generated field-by-field one.
    def __repr__(self) -> str:
        return self.represent()

    def represent(self) -> str:
        return self.name + ('[' + ', '.join([generic.represent() for generic in self.generics]) + ']' if self.generics else '')

    def fill(self, types: Dict[str, Type], stack: List[int]) -> Type:
        """Substitute variables in members and generics in place.

        Stops (returning ``self``) if ``self`` is already being filled further
        up the call chain, which guards against cycles.
        """
        if id(self) in stack:
            return self
        self.members = {member_name: [element.fill(types, stack + [id(self)]) for element in self.members[member_name]] for member_name in self.members}
        self.generics = [type.fill(types, stack + [id(self)]) for type in self.generics]
        return self

    def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]:
        """Return ``(changed, enum)`` with variables substituted from ``types``.

        Member payloads *and* ``generics`` are substituted, matching
        :meth:`fill`. Without substituting the generics, an instantiated enum
        would still print its unsubstituted type arguments. ``self`` is
        returned unchanged when no substitution happened.
        """
        assert id(self) not in stack
        is_new = False
        new_members: Dict[str, List[Type]] = {}
        for member_name, member in self.members.items():
            is_new_member, new_members[member_name] = self.new_fill_list(member, types, stack)
            is_new = is_new or is_new_member
        is_new_generics, new_generics = self.new_fill_list(self.generics, types, stack)
        is_new = is_new or is_new_generics
        return (is_new, EnumType(self.name, new_members, new_generics) if is_new else self)
|
2024-08-12 00:00:03 +10:00
|
|
|
|
2024-08-08 21:54:03 +10:00
|
|
|
@dataclass
class StructType(Type):
    """Named struct type mapping field names to field types."""

    name: str
    members: Dict[str, Type]
    generics: List[Type]

    def represent(self) -> str:
        # Printing generic structs is not supported.
        assert not self.generics
        return self.name

    def fill(self, types: Dict[str, Type], stack: List[int]) -> Type:
        """Substitute variables in field types in place (generics are left alone)."""
        if id(self) in stack:
            return self
        fields = self.members
        for field in fields:
            fields[field] = fields[field].fill(types, stack + [id(self)])
        return self

    def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]:
        """Return ``(changed, struct)``; ``self`` is returned when nothing changed."""
        assert id(self) not in stack
        changed, filled_members = self.new_fill_dict(self.members, types, stack)
        if not changed:
            return (changed, self)
        return (changed, StructType(self.name, filled_members, self.generics))
|
2024-08-12 00:00:03 +10:00
|
|
|
|
2024-08-08 21:54:03 +10:00
|
|
|
@dataclass
class VariableType(Type):
    """Type variable identified by ``name``, resolved through a name -> type mapping."""

    name: str

    def represent(self) -> str:
        suffix = '?'
        return self.name + suffix

    def fill(self, types: Dict[str, Type], stack: List[int]) -> Type:
        # Bound variables are replaced; unbound ones stay as placeholders.
        if self.name in types:
            return types[self.name]
        return self

    def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]:
        bound = self.name in types
        return (bound, types[self.name] if bound else self)
|
|
|
|
|
|
|
|
@dataclass
class GenericType(Type):
    """A type parameterised over ``variables``; instantiate it with :meth:`substitute`."""

    variables: List[VariableType]
    type: Type

    def represent(self) -> str:
        # Uninstantiated generics have no printable form.
        assert False

    def fill(self, types: Dict[str, Type], stack: List[int]) -> Type:
        """Fill the wrapped type in place; stops if ``self`` is already being filled."""
        if id(self) not in stack:
            self.type = self.type.fill(types, stack + [id(self)])
        return self

    def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]:
        assert False

    def substitute(self, types: List[Type]) -> Type:
        """Bind ``variables`` positionally to ``types`` and return the filled type.

        Raises:
            AssertionError: if the number of type arguments does not match.
        """
        assert len(types) == len(self.variables), f"{self.type.represent()} expected {len(self.variables)} type parameters, but got {len(types)}!"
        bindings = {variable.name: argument for variable, argument in zip(self.variables, types)}
        _, instantiated = self.type.new_fill(bindings, [])
        return instantiated
|