Make array literals explicitly state the element type

This commit is contained in:
germax26 2024-10-01 14:23:25 +10:00
parent e48d50f1e6
commit 04fff7514e
Signed by: germax26
SSH Key Fingerprint: SHA256:N3w+8798IMWBt7SYH8G1C0iJlIa2HIIcRCXwILT5FvM
3 changed files with 17 additions and 23 deletions

View File

@ -137,6 +137,7 @@ class ArrayAccess(Expression):
@dataclass @dataclass
class Array(Expression): class Array(Expression):
element_type: TypeExpression
array: List[Expression] array: List[Expression]
def represent(self) -> str: def represent(self) -> str:
@ -193,6 +194,7 @@ class StructInstantiation(Expression):
@dataclass @dataclass
class LoopComprehension(Expression): class LoopComprehension(Expression):
element_type: TypeExpression
body: Expression body: Expression
variable: str # TODO: Pattern matching variable: str # TODO: Pattern matching
array: Expression array: Expression

View File

@ -288,38 +288,26 @@ def calculate_expression(expression: Expression, program: ProgramState) -> Objec
assert False, ("Unimplemented", expression_) assert False, ("Unimplemented", expression_)
case UnaryMinus (expression_): case UnaryMinus (expression_):
assert False, ("Unimplemented", expression_) assert False, ("Unimplemented", expression_)
case Array(array_): case Array(element_type_, array_):
if len(array_) == 0: element_type = calculate_type_expression(element_type_, program)
return ListObject(ListType(VariableType("")), [])
elements_type: Optional[Type] = None
array_elements_: List_[Object] = [] array_elements_: List_[Object] = []
for element_ in array_: for element_ in array_:
element = calculate_expression(element_, program) element = calculate_expression(element_, program)
if elements_type: assert element.get_type().is_subtype_of(element_type), (element, element_type)
assert element.get_type().is_subtype_of(elements_type), (element, elements_type)
else:
elements_type = element.get_type()
array_elements_.append(element) array_elements_.append(element)
assert elements_type return ListObject(ListType(element_type), array_elements_)
return ListObject(ListType(elements_type), array_elements_) case LoopComprehension(element_type_, body_, variable, array_):
case LoopComprehension(body_, variable, array_): element_type = calculate_type_expression(element_type_, program)
array = calculate_expression(array_, program) array = calculate_expression(array_, program)
assert array.get_type().is_indexable() assert array.get_type().is_indexable()
if isinstance(array, ListObject): if isinstance(array, ListObject):
elements: List_[Object] = [] elements: List_[Object] = []
elements_type = None
for element in array.list: for element in array.list:
program.push_context({variable: Declared.from_obj(element)}) program.push_context({variable: Declared.from_obj(element)})
elements.append(calculate_expression(body_, program)) elements.append(calculate_expression(body_, program))
program.pop_context() program.pop_context()
if elements_type: assert elements[-1].get_type().is_subtype_of(element_type)
assert elements[-1].get_type().is_subtype_of(elements_type) return ListObject(ListType(element_type), elements)
else:
elements_type = elements[-1].get_type()
if not elements: return ListObject(ListType(VariableType("")), [])
assert elements_type
return ListObject(ListType(elements_type), elements)
else: else:
assert False, ("Unimplemented", array) assert False, ("Unimplemented", array)
case _: case _:

View File

@ -108,21 +108,25 @@ def parse_primary(lexer: Lexer) -> Expression:
else: else:
base_expression = elements[0] base_expression = elements[0]
elif lexer.take_token(SymbolToken(Symbol.OpenSquare)): elif lexer.take_token(SymbolToken(Symbol.OpenSquare)):
lexer.assert_token(SymbolToken(Symbol.Colon))
element_type = parse_type(lexer)
if lexer.take_token(SymbolToken(Symbol.CloseSquare)): if lexer.take_token(SymbolToken(Symbol.CloseSquare)):
base_expression = Array([]) base_expression = Array(element_type, [])
else: else:
lexer.assert_token(SymbolToken(Symbol.Comma))
expressions: List[Expression] = [parse_expression(lexer)] expressions: List[Expression] = [parse_expression(lexer)]
if lexer.take_token(KeywordToken(Keyword.For)): if lexer.take_token(KeywordToken(Keyword.For)):
variable = parse_identifier(lexer) # TODO: Pattern matching variable = parse_identifier(lexer) # TODO: Pattern matching
lexer.assert_token(KeywordToken(Keyword.In)) lexer.assert_token(KeywordToken(Keyword.In))
expression = parse_expression(lexer) expression = parse_expression(lexer)
lexer.assert_token(SymbolToken(Symbol.CloseSquare)) lexer.assert_token(SymbolToken(Symbol.CloseSquare))
base_expression = LoopComprehension(expressions[0], variable, expression) base_expression = LoopComprehension(element_type, expressions[0], variable, expression)
else: else:
while lexer.take_token(SymbolToken(Symbol.Comma)): while lexer.take_token(SymbolToken(Symbol.Comma)):
expressions.append(parse_expression(lexer)) expressions.append(parse_expression(lexer))
lexer.assert_token(SymbolToken(Symbol.CloseSquare)) lexer.assert_token(SymbolToken(Symbol.CloseSquare))
base_expression = Array(expressions) base_expression = Array(element_type, expressions)
elif lexer.check_tokenkind(StringToken): elif lexer.check_tokenkind(StringToken):
base_expression = String(parse_string(lexer)) base_expression = String(parse_string(lexer))
elif lexer.check_tokenkind(NumberToken): elif lexer.check_tokenkind(NumberToken):