Make return a regular statement instead of an expression

This commit is contained in:
germax26 2024-10-01 11:51:54 +10:00
parent 18b22cd5d1
commit f3ed26f131
Signed by: germax26
SSH Key Fingerprint: SHA256:N3w+8798IMWBt7SYH8G1C0iJlIa2HIIcRCXwILT5FvM
6 changed files with 21 additions and 47 deletions

View File

@ -202,16 +202,6 @@ class LoopComprehension(Expression):
def precedence(self) -> int: return 13 def precedence(self) -> int: return 13
@dataclass
class Return(Expression):
expression: Expression
def represent(self) -> str:
# TODO: This will have to be improved
return "return "+self.wrap(self.expression)
def precedence(self) -> int: return 0
@dataclass @dataclass
class Lambda(Expression): class Lambda(Expression):
parameters: List[TypeDeclaration] parameters: List[TypeDeclaration]
@ -484,7 +474,6 @@ class DoWhileStatement(Statement):
body: Statement body: Statement
condition: Optional[Expression] condition: Optional[Expression]
# TODO: Maybe do something similar to return with these two?
@dataclass @dataclass
class BreakStatement(Statement): class BreakStatement(Statement):
pass pass
@ -493,6 +482,10 @@ class BreakStatement(Statement):
class ContinueStatement(Statement): class ContinueStatement(Statement):
pass pass
@dataclass
class ReturnStatement(Statement):
expression: Expression
@dataclass @dataclass
class MatchStatement(Statement): class MatchStatement(Statement):
value: Expression value: Expression

View File

@ -3,11 +3,11 @@ from typing import Dict, List as List_, Optional, Tuple
from ppp_ast import * from ppp_ast import *
from ppp_lexer import Lexer from ppp_lexer import Lexer
from ppp_object import Bool, EnumValue, Function, Int, Object, Str, Struct, Tuple as TupleObject, List as ListObject, Return as ReturnObject, TypeObject, Void from ppp_object import Bool, EnumValue, Function, Int, Object, Str, Struct, Tuple as TupleObject, List as ListObject, TypeObject, Void
from ppp_parser import is_valid_target, parse_statement from ppp_parser import is_valid_target, parse_statement
from ppp_tokens import EofToken from ppp_tokens import EofToken
from ppp_stdlib import variables from ppp_stdlib import variables
from ppp_types import EnumType, FunctionType, GenericType, Int as IntType, ListType, ReturnType, Str as StrType, StructType, TupleType, Type, TypeType, VariableType, Void as VoidType from ppp_types import EnumType, FunctionType, GenericType, Int as IntType, ListType, Str as StrType, StructType, TupleType, Type, TypeType, VariableType, Void as VoidType
@dataclass @dataclass
class Declared: class Declared:
@ -218,9 +218,6 @@ def calculate_expression(expression: Expression, program: ProgramState) -> Objec
case Int(num): return Str(left_value.str % num) case Int(num): return Str(left_value.str % num)
case _: assert False, ("Unimplemented", right_value) case _: assert False, ("Unimplemented", right_value)
assert False, ("Unimplemented", lhs, rhs) assert False, ("Unimplemented", lhs, rhs)
case Return(expression):
value = calculate_expression(expression, program)
return ReturnObject(ReturnType(value.get_type()), value)
case StructInstantiation(struct_, arguments_): case StructInstantiation(struct_, arguments_):
struct = calculate_expression(struct_, program) struct = calculate_expression(struct_, program)
assert isinstance(struct, TypeObject) assert isinstance(struct, TypeObject)
@ -417,8 +414,7 @@ def interpret_statements(statements: List_[Statement], program: ProgramState) ->
for statement in statements: for statement in statements:
match statement: match statement:
case ExpressionStatement(expression): case ExpressionStatement(expression):
value = calculate_expression(expression, program) calculate_expression(expression, program)
if isinstance(value, ReturnObject): return ReturnResult(value.value)
case Assignment(lhs, rhs, type_): case Assignment(lhs, rhs, type_):
assert is_valid_target(lhs) assert is_valid_target(lhs)
match lhs: match lhs:
@ -552,6 +548,8 @@ def interpret_statements(statements: List_[Statement], program: ProgramState) ->
case _: assert False, ("Unimplemented", return_value) case _: assert False, ("Unimplemented", return_value)
case ContinueStatement(): return ContinueResult() case ContinueStatement(): return ContinueResult()
case BreakStatement(): return BreakResult() case BreakStatement(): return BreakResult()
case ReturnStatement(expression=expression):
return ReturnResult(calculate_expression(expression, program))
case Import(file): case Import(file):
# TODO: Maybe an inclusion system within a preprocessor maybe # TODO: Maybe an inclusion system within a preprocessor maybe
module = interpret_file(file, program.modules) if file not in program.modules else program.modules[file] module = interpret_file(file, program.modules) if file not in program.modules else program.modules[file]
@ -578,7 +576,7 @@ def interpret_file(file_path: str, modules: Dict[str, Module]) -> Module:
assert len(program.contexts) == 2 assert len(program.contexts) == 2
match return_value: match return_value:
case NothingResult(): pass case NothingResult(): pass
case ReturnObject(_): assert False, "Cannot return from outside a function!" case ReturnResult(_): assert False, "Cannot return from outside a function!"
case ContinueResult(): assert False, "Cannot continue from outside a loop!" case ContinueResult(): assert False, "Cannot continue from outside a loop!"
case BreakResult(): assert False, "Cannot break from outside a loop!" case BreakResult(): assert False, "Cannot break from outside a loop!"
case _: assert False, ("Unimplemented", return_value) case _: assert False, ("Unimplemented", return_value)

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass
from typing import Callable, Dict, List as List_, Tuple as Tuple_ from typing import Callable, Dict, List as List_, Tuple as Tuple_
from ppp_ast import Statement from ppp_ast import Statement
from ppp_types import ArrayType, EnumType, FunctionType, ListType, ReturnType, StructType, TupleType, Type, Int as IntType, Str as StrType, Bool as BoolType, Void as VoidType, TypeType from ppp_types import ArrayType, EnumType, FunctionType, ListType, StructType, TupleType, Type, Int as IntType, Str as StrType, Bool as BoolType, Void as VoidType, TypeType
class Object(ABC): class Object(ABC):
@abstractmethod @abstractmethod
@ -68,13 +68,6 @@ class Function(Object):
def get_type(self) -> Type: return self.type def get_type(self) -> Type: return self.type
@dataclass
class Return(Object):
type: ReturnType
value: Object
def get_type(self) -> Type: return self.type
@dataclass @dataclass
class EnumValue(Object): class EnumValue(Object):
type: EnumType type: EnumType

View File

@ -173,7 +173,6 @@ def parse_unary(lexer: Lexer) -> Expression:
if lexer.take_token(SymbolToken(Symbol.Exclamation)): return Not(parse_unary(lexer)) if lexer.take_token(SymbolToken(Symbol.Exclamation)): return Not(parse_unary(lexer))
if lexer.take_token(SymbolToken(Symbol.Plus)): return UnaryPlus(parse_unary(lexer)) if lexer.take_token(SymbolToken(Symbol.Plus)): return UnaryPlus(parse_unary(lexer))
if lexer.take_token(SymbolToken(Symbol.Dash)): return UnaryMinus(parse_unary(lexer)) if lexer.take_token(SymbolToken(Symbol.Dash)): return UnaryMinus(parse_unary(lexer))
if lexer.take_token(KeywordToken(Keyword.Return)): return Return(parse_unary(lexer))
return parse_primary(lexer) return parse_primary(lexer)
Precedence = Dict[Symbol, Callable[[Expression, Expression], Expression]] Precedence = Dict[Symbol, Callable[[Expression, Expression], Expression]]
@ -209,7 +208,6 @@ def parse_ternary(lexer: Lexer) -> Expression:
return Ternary(expression, if_true, if_false) return Ternary(expression, if_true, if_false)
def parse_expression(lexer: Lexer) -> Expression: def parse_expression(lexer: Lexer) -> Expression:
if lexer.take_token(KeywordToken(Keyword.Return)): return Return(parse_expression(lexer))
if lexer.take_token(KeywordToken(Keyword.Lambda)): if lexer.take_token(KeywordToken(Keyword.Lambda)):
parameters: List[TypeDeclaration] parameters: List[TypeDeclaration]
if lexer.take_token(SymbolToken(Symbol.EqualArrow)): if lexer.take_token(SymbolToken(Symbol.EqualArrow)):
@ -290,6 +288,10 @@ def parse_statement(lexer: Lexer) -> Statement:
elif lexer.take_token(KeywordToken(Keyword.Continue)): elif lexer.take_token(KeywordToken(Keyword.Continue)):
lexer.assert_token(SymbolToken(Symbol.Semicolon)) lexer.assert_token(SymbolToken(Symbol.Semicolon))
return ContinueStatement() return ContinueStatement()
elif lexer.take_token(KeywordToken(Keyword.Return)):
expression = parse_expression(lexer)
lexer.assert_token(SymbolToken(Symbol.Semicolon))
return ReturnStatement(expression)
elif lexer.take_token(KeywordToken(Keyword.Do)): elif lexer.take_token(KeywordToken(Keyword.Do)):
body = parse_statement(lexer) body = parse_statement(lexer)
condition: Optional[Expression] = None condition: Optional[Expression] = None
@ -330,8 +332,10 @@ def parse_statement(lexer: Lexer) -> Statement:
elif lexer.take_token(KeywordToken(Keyword.Defer)): elif lexer.take_token(KeywordToken(Keyword.Defer)):
statement = parse_statement(lexer) statement = parse_statement(lexer)
return DeferStatement(statement) return DeferStatement(statement)
elif lexer.check_tokenkind(KeywordToken) and not lexer.check_tokens(KeywordToken(Keyword.Return), KeywordToken(Keyword.Lambda)): elif lexer.check_tokenkind(KeywordToken) and not lexer.check_token(KeywordToken(Keyword.Lambda)): # TODO: Maybe use '\' for lambda instead of a keyword
assert False, ("Unimplemented", lexer.next_token(), lexer.next_token(), lexer.next_token()) token = lexer.next_token()
assert isinstance(token.contents, KeywordToken)
raise SyntaxError(f"{token.loc}: Unexpected keyword: '{token.contents.keyword}'")
elif lexer.take_token(SymbolToken(Symbol.OpenCurly)): elif lexer.take_token(SymbolToken(Symbol.OpenCurly)):
statements: List[Statement] = [] statements: List[Statement] = []
while not lexer.take_token(SymbolToken(Symbol.CloseCurly)): while not lexer.take_token(SymbolToken(Symbol.CloseCurly)):

View File

@ -23,6 +23,8 @@ class Keyword(Enum):
Type = 'type' Type = 'type'
Defer = 'defer' Defer = 'defer'
def __str__(self) -> str: return self._value_
class Symbol(Enum): class Symbol(Enum):
Open = '(' Open = '('
Close = ')' Close = ')'

View File

@ -164,22 +164,6 @@ class ObjectType(Primitive):
def represent(self) -> str: return 'object' def represent(self) -> str: return 'object'
Object = ObjectType() Object = ObjectType()
@dataclass
class ReturnType(Type):
type: Type
def represent(self) -> str: return f"return<{self.type.represent()}>"
def fill(self, types: Dict[str, Type], stack: List[int]) -> Type:
if id(self) in stack: return self
self.type = self.type.fill(types, stack+[id(self)])
return self
def new_fill(self, types: Dict[str, Type], stack: List[int]) -> Tuple[bool, Type]:
assert id(self) not in stack
is_new, new_type = self.type.new_fill(types, stack+[id(self)])
return (is_new, ReturnType(new_type))
num_expressions: int = 0 num_expressions: int = 0
@dataclass @dataclass