Make array literals explicitly state the element type

This commit is contained in:
germax26 2024-10-01 14:23:25 +10:00
parent e48d50f1e6
commit 04fff7514e
Signed by: germax26
SSH Key Fingerprint: SHA256:N3w+8798IMWBt7SYH8G1C0iJlIa2HIIcRCXwILT5FvM
3 changed files with 17 additions and 23 deletions

View File

@ -137,6 +137,7 @@ class ArrayAccess(Expression):
@dataclass
class Array(Expression):
element_type: TypeExpression
array: List[Expression]
def represent(self) -> str:
@ -193,6 +194,7 @@ class StructInstantiation(Expression):
@dataclass
class LoopComprehension(Expression):
element_type: TypeExpression
body: Expression
variable: str # TODO: Pattern matching
array: Expression

View File

@ -288,38 +288,26 @@ def calculate_expression(expression: Expression, program: ProgramState) -> Objec
assert False, ("Unimplemented", expression_)
case UnaryMinus (expression_):
assert False, ("Unimplemented", expression_)
case Array(array_):
if len(array_) == 0:
return ListObject(ListType(VariableType("")), [])
elements_type: Optional[Type] = None
case Array(element_type_, array_):
element_type = calculate_type_expression(element_type_, program)
array_elements_: List_[Object] = []
for element_ in array_:
element = calculate_expression(element_, program)
if elements_type:
assert element.get_type().is_subtype_of(elements_type), (element, elements_type)
else:
elements_type = element.get_type()
assert element.get_type().is_subtype_of(element_type), (element, element_type)
array_elements_.append(element)
assert elements_type
return ListObject(ListType(elements_type), array_elements_)
case LoopComprehension(body_, variable, array_):
return ListObject(ListType(element_type), array_elements_)
case LoopComprehension(element_type_, body_, variable, array_):
element_type = calculate_type_expression(element_type_, program)
array = calculate_expression(array_, program)
assert array.get_type().is_indexable()
if isinstance(array, ListObject):
elements: List_[Object] = []
elements_type = None
for element in array.list:
program.push_context({variable: Declared.from_obj(element)})
elements.append(calculate_expression(body_, program))
program.pop_context()
if elements_type:
assert elements[-1].get_type().is_subtype_of(elements_type)
else:
elements_type = elements[-1].get_type()
if not elements: return ListObject(ListType(VariableType("")), [])
assert elements_type
return ListObject(ListType(elements_type), elements)
assert elements[-1].get_type().is_subtype_of(element_type)
return ListObject(ListType(element_type), elements)
else:
assert False, ("Unimplemented", array)
case _:

View File

@ -108,21 +108,25 @@ def parse_primary(lexer: Lexer) -> Expression:
else:
base_expression = elements[0]
elif lexer.take_token(SymbolToken(Symbol.OpenSquare)):
lexer.assert_token(SymbolToken(Symbol.Colon))
element_type = parse_type(lexer)
if lexer.take_token(SymbolToken(Symbol.CloseSquare)):
base_expression = Array([])
base_expression = Array(element_type, [])
else:
lexer.assert_token(SymbolToken(Symbol.Comma))
expressions: List[Expression] = [parse_expression(lexer)]
if lexer.take_token(KeywordToken(Keyword.For)):
variable = parse_identifier(lexer) # TODO: Pattern matching
lexer.assert_token(KeywordToken(Keyword.In))
expression = parse_expression(lexer)
lexer.assert_token(SymbolToken(Symbol.CloseSquare))
base_expression = LoopComprehension(expressions[0], variable, expression)
base_expression = LoopComprehension(element_type, expressions[0], variable, expression)
else:
while lexer.take_token(SymbolToken(Symbol.Comma)):
expressions.append(parse_expression(lexer))
lexer.assert_token(SymbolToken(Symbol.CloseSquare))
base_expression = Array(expressions)
base_expression = Array(element_type, expressions)
elif lexer.check_tokenkind(StringToken):
base_expression = String(parse_string(lexer))
elif lexer.check_tokenkind(NumberToken):