Simplify addressing

This commit is contained in:
nub31
2025-09-08 20:22:44 +02:00
parent 3295f76001
commit 343d515f37
5 changed files with 143 additions and 110 deletions

View File

@@ -19,7 +19,7 @@ struct Human : Printable
func main(args: []cstring): i64 func main(args: []cstring): i64
{ {
let x: Printable = struct Human { let x = struct Human {
name = "Oliver" name = "Oliver"
age = 23 age = 23
} }

View File

@@ -230,7 +230,7 @@ public class QBEGenerator
return; return;
} }
var value = EmitUnwrap(EmitExpression(source)); var value = EmitExpression(source);
if (source.Type.IsSimpleType(out var simpleType, out var complexType)) if (source.Type.IsSimpleType(out var simpleType, out var complexType))
{ {
@@ -269,7 +269,7 @@ public class QBEGenerator
case InterfaceInitializerNode: case InterfaceInitializerNode:
case LiteralNode { Kind: LiteralKind.String }: case LiteralNode { Kind: LiteralKind.String }:
{ {
destination = EmitUnwrap(EmitExpression(source)); destination = EmitExpression(source);
return true; return true;
} }
} }
@@ -286,7 +286,7 @@ public class QBEGenerator
return uncopiedValue; return uncopiedValue;
} }
var value = EmitUnwrap(EmitExpression(source)); var value = EmitExpression(source);
if (source.Type.IsSimpleType(out _, out var complexType)) if (source.Type.IsSimpleType(out _, out var complexType))
{ {
@@ -508,13 +508,7 @@ public class QBEGenerator
private void EmitAssignment(AssignmentNode assignment) private void EmitAssignment(AssignmentNode assignment)
{ {
if (!assignment.Target.IsLValue) EmitCopyIntoOrInitialize(assignment.Value, EmitAddressOfLValue(assignment.Target));
{
throw new UnreachableException("Destination of assignment must be an lvalue. This should have been caught in the type checker");
}
var destination = EmitExpression(assignment.Target);
EmitCopyIntoOrInitialize(assignment.Value, destination.Name);
} }
private void EmitBreak() private void EmitBreak()
@@ -535,7 +529,7 @@ public class QBEGenerator
var falseLabel = LabelName(); var falseLabel = LabelName();
var endLabel = LabelName(); var endLabel = LabelName();
var result = EmitUnwrap(EmitExpression(ifStatement.Condition)); var result = EmitExpression(ifStatement.Condition);
_writer.Indented($"jnz {result}, {trueLabel}, {falseLabel}"); _writer.Indented($"jnz {result}, {trueLabel}, {falseLabel}");
_writer.WriteLine(trueLabel); _writer.WriteLine(trueLabel);
EmitBlock(ifStatement.Body); EmitBlock(ifStatement.Body);
@@ -543,11 +537,7 @@ public class QBEGenerator
_writer.WriteLine(falseLabel); _writer.WriteLine(falseLabel);
if (ifStatement.Else.HasValue) if (ifStatement.Else.HasValue)
{ {
ifStatement.Else.Value.Match ifStatement.Else.Value.Match(EmitIf, EmitBlock);
(
elseIfNode => EmitIf(elseIfNode),
elseNode => EmitBlock(elseNode)
);
} }
_writer.WriteLine(endLabel); _writer.WriteLine(endLabel);
@@ -557,7 +547,7 @@ public class QBEGenerator
{ {
if (@return.Value.HasValue) if (@return.Value.HasValue)
{ {
var result = EmitUnwrap(EmitExpression(@return.Value.Value)); var result = EmitExpression(@return.Value.Value);
_writer.Indented($"ret {result}"); _writer.Indented($"ret {result}");
} }
else else
@@ -591,7 +581,7 @@ public class QBEGenerator
_writer.WriteLine(iterationLabel); _writer.WriteLine(iterationLabel);
EmitBlock(whileStatement.Body); EmitBlock(whileStatement.Body);
_writer.WriteLine(conditionLabel); _writer.WriteLine(conditionLabel);
var result = EmitUnwrap(EmitExpression(whileStatement.Condition)); var result = EmitExpression(whileStatement.Condition);
_writer.Indented($"jnz {result}, {iterationLabel}, {endLabel}"); _writer.Indented($"jnz {result}, {iterationLabel}, {endLabel}");
_writer.WriteLine(endLabel); _writer.WriteLine(endLabel);
@@ -599,21 +589,22 @@ public class QBEGenerator
_breakLabels.Pop(); _breakLabels.Pop();
} }
private Val EmitExpression(ExpressionNode expression) private string EmitExpression(ExpressionNode expression)
{ {
var value = expression switch return expression switch
{ {
ArrayInitializerNode arrayInitializer => EmitArrayInitializer(arrayInitializer), ArrayInitializerNode arrayInitializer => EmitArrayInitializer(arrayInitializer),
StructInitializerNode structInitializer => EmitStructInitializer(structInitializer), StructInitializerNode structInitializer => EmitStructInitializer(structInitializer),
AddressOfNode addressOf => EmitAddressOf(addressOf), AddressOfNode addressOf => EmitAddressOf(addressOf),
DereferenceNode dereference => EmitDereference(dereference), DereferenceNode dereference => EmitDereference(dereference),
BinaryExpressionNode binaryExpression => EmitBinaryExpression(binaryExpression), BinaryExpressionNode binary => EmitBinaryExpression(binary),
FuncCallNode funcCallExpression => EmitFuncCall(funcCallExpression), FuncCallNode funcCall => EmitFuncCall(funcCall),
InterfaceFuncCallNode interfaceFuncCall => EmitInterfaceFuncCall(interfaceFuncCall), InterfaceFuncCallNode interfaceFuncCall => EmitInterfaceFuncCall(interfaceFuncCall),
InterfaceInitializerNode interfaceInitializer => EmitInterfaceInitializer(interfaceInitializer), InterfaceInitializerNode interfaceInitializer => EmitInterfaceInitializer(interfaceInitializer),
ExternFuncIdentNode externFuncIdent => EmitExternFuncIdent(externFuncIdent), ExternFuncIdentNode externFuncIdent => EmitExternFuncIdent(externFuncIdent),
LocalFuncIdentNode localFuncIdent => EmitLocalFuncIdent(localFuncIdent), LocalFuncIdentNode localFuncIdent => EmitLocalFuncIdent(localFuncIdent),
VariableIdentNode variableIdent => EmitVariableIdent(variableIdent), VariableIdentNode variableIdent => EmitVariableIdent(variableIdent),
FuncParameterIdentNode funcParameterIdent => EmitFuncParameterIdent(funcParameterIdent),
LiteralNode literal => EmitLiteral(literal), LiteralNode literal => EmitLiteral(literal),
UnaryExpressionNode unaryExpression => EmitUnaryExpression(unaryExpression), UnaryExpressionNode unaryExpression => EmitUnaryExpression(unaryExpression),
StructFieldAccessNode structFieldAccess => EmitStructFieldAccess(structFieldAccess), StructFieldAccessNode structFieldAccess => EmitStructFieldAccess(structFieldAccess),
@@ -621,27 +612,14 @@ public class QBEGenerator
ArrayIndexAccessNode arrayIndex => EmitArrayIndexAccess(arrayIndex), ArrayIndexAccessNode arrayIndex => EmitArrayIndexAccess(arrayIndex),
_ => throw new ArgumentOutOfRangeException(nameof(expression)) _ => throw new ArgumentOutOfRangeException(nameof(expression))
}; };
return new Val(value, expression.Type, expression.IsLValue);
} }
private string EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess) private string EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess)
{ {
var array = EmitUnwrap(EmitExpression(arrayIndexAccess.Target)); return EmitLoad(arrayIndexAccess.Type, EmitAddressOfArrayIndexAccess(arrayIndexAccess));
var index = EmitUnwrap(EmitExpression(arrayIndexAccess.Index));
EmitArraysCheck(array, index);
var elementType = ((ArrayTypeNode)arrayIndexAccess.Target.Type).ElementType;
var pointer = TmpName();
_writer.Indented($"{pointer} =l mul {index}, {SizeOf(elementType)}");
_writer.Indented($"{pointer} =l add {pointer}, 8");
_writer.Indented($"{pointer} =l add {array}, {pointer}");
return pointer;
} }
private void EmitArraysCheck(string array, string index) private void EmitArrayBoundsCheck(string array, string index)
{ {
var count = TmpName(); var count = TmpName();
_writer.Indented($"{count} =l loadl {array}"); _writer.Indented($"{count} =l loadl {array}");
@@ -667,7 +645,7 @@ public class QBEGenerator
private string EmitArrayInitializer(ArrayInitializerNode arrayInitializer) private string EmitArrayInitializer(ArrayInitializerNode arrayInitializer)
{ {
var capacity = EmitUnwrap(EmitExpression(arrayInitializer.Capacity)); var capacity = EmitExpression(arrayInitializer.Capacity);
var elementSize = SizeOf(arrayInitializer.ElementType); var elementSize = SizeOf(arrayInitializer.ElementType);
var capacityInBytes = TmpName(); var capacityInBytes = TmpName();
@@ -688,24 +666,65 @@ public class QBEGenerator
private string EmitDereference(DereferenceNode dereference) private string EmitDereference(DereferenceNode dereference)
{ {
return EmitLoad(dereference.Type, EmitUnwrap(EmitExpression(dereference.Expression))); return EmitLoad(dereference.Type, EmitExpression(dereference.Expression));
} }
private string EmitAddressOf(AddressOfNode addressOf) private string EmitAddressOf(AddressOfNode addressOf)
{ {
var value = EmitExpression(addressOf.Expression); return EmitAddressOfLValue(addressOf.LValue);
if (!value.IsLValue)
{
throw new UnreachableException("Tried to take address of rvalue. This should have been caught in the type checker");
} }
return value.Name; private string EmitAddressOfLValue(LValueExpressionNode addressOf)
{
return addressOf switch
{
ArrayIndexAccessNode arrayIndexAccess => EmitAddressOfArrayIndexAccess(arrayIndexAccess),
ArrayInitializerNode arrayInitializer => EmitArrayInitializer(arrayInitializer),
InterfaceInitializerNode interfaceInitializer => EmitInterfaceInitializer(interfaceInitializer),
StructFieldAccessNode structFieldAccess => EmitAddressOfStructFieldAccess(structFieldAccess),
StructInitializerNode structInitializer => EmitStructInitializer(structInitializer),
VariableIdentNode variableIdent => EmitAddressOfVariableIdent(variableIdent),
_ => throw new ArgumentOutOfRangeException(nameof(addressOf))
};
}
private string EmitAddressOfArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess)
{
var array = EmitExpression(arrayIndexAccess.Target);
var index = EmitExpression(arrayIndexAccess.Index);
EmitArrayBoundsCheck(array, index);
var elementType = ((ArrayTypeNode)arrayIndexAccess.Target.Type).ElementType;
var address = TmpName();
_writer.Indented($"{address} =l mul {index}, {SizeOf(elementType)}");
_writer.Indented($"{address} =l add {address}, 8");
_writer.Indented($"{address} =l add {array}, {address}");
return address;
}
private string EmitAddressOfStructFieldAccess(StructFieldAccessNode structFieldAccess)
{
var target = EmitExpression(structFieldAccess.Target);
var structDef = _definitionTable.LookupStruct(structFieldAccess.StructType.Name);
var offset = OffsetOf(structDef, structFieldAccess.Field);
var address = TmpName();
_writer.Indented($"{address} =l add {target}, {offset}");
return address;
}
private string EmitAddressOfVariableIdent(VariableIdentNode variableIdent)
{
return "%" + variableIdent.Name;
} }
private string EmitBinaryExpression(BinaryExpressionNode binaryExpression) private string EmitBinaryExpression(BinaryExpressionNode binaryExpression)
{ {
var left = EmitUnwrap(EmitExpression(binaryExpression.Left)); var left = EmitExpression(binaryExpression.Left);
var right = EmitUnwrap(EmitExpression(binaryExpression.Right)); var right = EmitExpression(binaryExpression.Right);
var outputName = TmpName(); var outputName = TmpName();
@@ -811,19 +830,22 @@ public class QBEGenerator
private string EmitExternFuncIdent(ExternFuncIdentNode externFuncIdent) private string EmitExternFuncIdent(ExternFuncIdentNode externFuncIdent)
{ {
var func = _definitionTable.LookupExternFunc(externFuncIdent.Name); return ExternFuncName(_definitionTable.LookupExternFunc(externFuncIdent.Name));
return ExternFuncName(func);
} }
private string EmitLocalFuncIdent(LocalFuncIdentNode localFuncIdent) private string EmitLocalFuncIdent(LocalFuncIdentNode localFuncIdent)
{ {
var func = _definitionTable.LookupLocalFunc(localFuncIdent.Name); return LocalFuncName(_definitionTable.LookupLocalFunc(localFuncIdent.Name));
return LocalFuncName(func);
} }
private string EmitVariableIdent(VariableIdentNode variableIdent) private string EmitVariableIdent(VariableIdentNode variableIdent)
{ {
return variableIdent.Name; return EmitLoad(variableIdent.Type, EmitAddressOfVariableIdent(variableIdent));
}
private string EmitFuncParameterIdent(FuncParameterIdentNode funcParameterIdent)
{
return "%" + funcParameterIdent.Name;
} }
private string EmitLiteral(LiteralNode literal) private string EmitLiteral(LiteralNode literal)
@@ -941,7 +963,7 @@ public class QBEGenerator
private string EmitUnaryExpression(UnaryExpressionNode unaryExpression) private string EmitUnaryExpression(UnaryExpressionNode unaryExpression)
{ {
var operand = EmitUnwrap(EmitExpression(unaryExpression.Operand)); var operand = EmitExpression(unaryExpression.Operand);
var outputName = TmpName(); var outputName = TmpName();
switch (unaryExpression.Operator) switch (unaryExpression.Operator)
@@ -988,15 +1010,7 @@ public class QBEGenerator
private string EmitStructFieldAccess(StructFieldAccessNode structFieldAccess) private string EmitStructFieldAccess(StructFieldAccessNode structFieldAccess)
{ {
var target = EmitUnwrap(EmitExpression(structFieldAccess.Target)); return EmitLoad(structFieldAccess.Type, EmitAddressOfStructFieldAccess(structFieldAccess));
var structDef = _definitionTable.LookupStruct(structFieldAccess.StructType.Name);
var offset = OffsetOf(structDef, structFieldAccess.Field);
var output = TmpName();
_writer.Indented($"{output} =l add {target}, {offset}");
return output;
} }
private string EmitStructFuncCall(StructFuncCallNode structFuncCall) private string EmitStructFuncCall(StructFuncCallNode structFuncCall)
@@ -1004,7 +1018,7 @@ public class QBEGenerator
var structDef = _definitionTable.LookupStruct(structFuncCall.StructType.Name); var structDef = _definitionTable.LookupStruct(structFuncCall.StructType.Name);
var func = StructFuncName(structDef.Name, structFuncCall.Name); var func = StructFuncName(structDef.Name, structFuncCall.Name);
var thisParameter = EmitUnwrap(EmitExpression(structFuncCall.StructExpression)); var thisParameter = EmitExpression(structFuncCall.StructExpression);
List<string> parameterStrings = [$"l {thisParameter}"]; List<string> parameterStrings = [$"l {thisParameter}"];
@@ -1029,7 +1043,7 @@ public class QBEGenerator
private string EmitInterfaceFuncCall(InterfaceFuncCallNode interfaceFuncCall) private string EmitInterfaceFuncCall(InterfaceFuncCallNode interfaceFuncCall)
{ {
var target = EmitUnwrap(EmitExpression(interfaceFuncCall.InterfaceExpression)); var target = EmitExpression(interfaceFuncCall.InterfaceExpression);
var interfaceDef = _definitionTable.LookupInterface(interfaceFuncCall.InterfaceType.Name); var interfaceDef = _definitionTable.LookupInterface(interfaceFuncCall.InterfaceType.Name);
var functionIndex = interfaceDef.Functions.ToList().FindIndex(x => x.Name == interfaceFuncCall.Name); var functionIndex = interfaceDef.Functions.ToList().FindIndex(x => x.Name == interfaceFuncCall.Name);
@@ -1070,7 +1084,7 @@ public class QBEGenerator
private string EmitInterfaceInitializer(InterfaceInitializerNode interfaceInitializer, string? destination = null) private string EmitInterfaceInitializer(InterfaceInitializerNode interfaceInitializer, string? destination = null)
{ {
var implementation = EmitUnwrap(EmitExpression(interfaceInitializer.Implementation)); var implementation = EmitExpression(interfaceInitializer.Implementation);
var vtableOffset = 0; var vtableOffset = 0;
foreach (var interfaceImplementation in interfaceInitializer.StructType.InterfaceImplementations) foreach (var interfaceImplementation in interfaceInitializer.StructType.InterfaceImplementations)
@@ -1102,7 +1116,7 @@ public class QBEGenerator
private string EmitFuncCall(FuncCallNode funcCall) private string EmitFuncCall(FuncCallNode funcCall)
{ {
var expression = EmitExpression(funcCall.Expression); var funcPointer = EmitExpression(funcCall.Expression);
var parameterStrings = new List<string>(); var parameterStrings = new List<string>();
@@ -1112,7 +1126,6 @@ public class QBEGenerator
parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {copy}"); parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {copy}");
} }
var funcPointer = EmitUnwrap(expression);
if (funcCall.Type is VoidTypeNode) if (funcCall.Type is VoidTypeNode)
{ {
_writer.Indented($"call {funcPointer}({string.Join(", ", parameterStrings)})"); _writer.Indented($"call {funcPointer}({string.Join(", ", parameterStrings)})");
@@ -1126,11 +1139,6 @@ public class QBEGenerator
} }
} }
private string EmitUnwrap(Val val)
{
return val.IsLValue ? EmitLoad(val.Type, val.Name) : val.Name;
}
private static int SizeOf(TypeNode type) private static int SizeOf(TypeNode type)
{ {
return type switch return type switch
@@ -1294,5 +1302,3 @@ public class CStringLiteral(string value, string name)
public string Value { get; } = value; public string Value { get; } = value;
public string Name { get; } = name; public string Name { get; } = name;
} }
public record Val(string Name, TypeNode Type, bool IsLValue);

View File

@@ -22,36 +22,41 @@ public enum BinaryOperator
Divide Divide
} }
public abstract record ExpressionNode(TypeNode Type, bool IsLValue) : Node; public abstract record ExpressionNode(TypeNode Type) : Node;
public record BinaryExpressionNode(TypeNode Type, ExpressionNode Left, BinaryOperator Operator, ExpressionNode Right) : ExpressionNode(Type, false); public abstract record LValueExpressionNode(TypeNode Type) : RValueExpressionNode(Type);
public abstract record RValueExpressionNode(TypeNode Type) : ExpressionNode(Type);
public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, ExpressionNode Operand) : ExpressionNode(Type, false); public record BinaryExpressionNode(TypeNode Type, ExpressionNode Left, BinaryOperator Operator, ExpressionNode Right) : RValueExpressionNode(Type);
public record FuncCallNode(TypeNode Type, ExpressionNode Expression, IReadOnlyList<ExpressionNode> Parameters) : ExpressionNode(Type, false); public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, ExpressionNode Operand) : RValueExpressionNode(Type);
public record StructFuncCallNode(TypeNode Type, string Name, StructTypeNode StructType, ExpressionNode StructExpression, IReadOnlyList<ExpressionNode> Parameters) : ExpressionNode(Type, false); public record FuncCallNode(TypeNode Type, ExpressionNode Expression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
public record InterfaceFuncCallNode(TypeNode Type, string Name, InterfaceTypeNode InterfaceType, ExpressionNode InterfaceExpression, IReadOnlyList<ExpressionNode> Parameters) : ExpressionNode(Type, false); public record StructFuncCallNode(TypeNode Type, string Name, StructTypeNode StructType, ExpressionNode StructExpression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
public record VariableIdentNode(TypeNode Type, string Name) : ExpressionNode(Type, true); public record InterfaceFuncCallNode(TypeNode Type, string Name, InterfaceTypeNode InterfaceType, ExpressionNode InterfaceExpression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
public record LocalFuncIdentNode(TypeNode Type, string Name) : ExpressionNode(Type, false); public record VariableIdentNode(TypeNode Type, string Name) : LValueExpressionNode(Type);
public record ExternFuncIdentNode(TypeNode Type, string Name) : ExpressionNode(Type, false); public record FuncParameterIdentNode(TypeNode Type, string Name) : RValueExpressionNode(Type);
public record ArrayInitializerNode(TypeNode Type, ExpressionNode Capacity, TypeNode ElementType) : ExpressionNode(Type, true); public record LocalFuncIdentNode(TypeNode Type, string Name) : RValueExpressionNode(Type);
public record ArrayIndexAccessNode(TypeNode Type, ExpressionNode Target, ExpressionNode Index) : ExpressionNode(Type, true); public record ExternFuncIdentNode(TypeNode Type, string Name) : RValueExpressionNode(Type);
public record AddressOfNode(TypeNode Type, ExpressionNode Expression) : ExpressionNode(Type, false); public record ArrayInitializerNode(TypeNode Type, ExpressionNode Capacity, TypeNode ElementType) : LValueExpressionNode(Type);
public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind) : ExpressionNode(Type, false); public record ArrayIndexAccessNode(TypeNode Type, ExpressionNode Target, ExpressionNode Index) : LValueExpressionNode(Type);
public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field) : ExpressionNode(Type, true); public record AddressOfNode(TypeNode Type, LValueExpressionNode LValue) : RValueExpressionNode(Type);
public record StructInitializerNode(StructTypeNode StructType, Dictionary<string, ExpressionNode> Initializers) : ExpressionNode(StructType, true); public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind) : RValueExpressionNode(Type);
public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : ExpressionNode(Type, false); public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field) : LValueExpressionNode(Type);
public record InterfaceInitializerNode(TypeNode Type, InterfaceTypeNode InterfaceType, StructTypeNode StructType, ExpressionNode Implementation) : ExpressionNode(Type, true); public record StructInitializerNode(StructTypeNode StructType, Dictionary<string, ExpressionNode> Initializers) : LValueExpressionNode(StructType);
public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : RValueExpressionNode(Type);
public record InterfaceInitializerNode(TypeNode Type, InterfaceTypeNode InterfaceType, StructTypeNode StructType, ExpressionNode Implementation) : LValueExpressionNode(Type);

View File

@@ -6,7 +6,7 @@ public record StatementExpressionNode(ExpressionNode Expression) : StatementNode
public record ReturnNode(Optional<ExpressionNode> Value) : StatementNode; public record ReturnNode(Optional<ExpressionNode> Value) : StatementNode;
public record AssignmentNode(ExpressionNode Target, ExpressionNode Value) : StatementNode; public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value) : StatementNode;
public record IfNode(ExpressionNode Condition, BlockNode Body, Optional<Variant<IfNode, BlockNode>> Else) : StatementNode; public record IfNode(ExpressionNode Condition, BlockNode Body, Optional<Variant<IfNode, BlockNode>> Else) : StatementNode;

View File

@@ -94,10 +94,10 @@ public sealed class TypeChecker
{ {
var scope = new Scope(); var scope = new Scope();
scope.Declare(new Variable("this", GetStructType(node))); scope.Declare(new Identifier("this", GetStructType(node), IdentifierKind.FunctionParameter));
foreach (var parameter in func.Signature.Parameters) foreach (var parameter in func.Signature.Parameters)
{ {
scope.Declare(new Variable(parameter.Name, CheckType(parameter.Type))); scope.Declare(new Identifier(parameter.Name, CheckType(parameter.Type), IdentifierKind.FunctionParameter));
} }
_funcReturnTypes.Push(CheckType(func.Signature.ReturnType)); _funcReturnTypes.Push(CheckType(func.Signature.ReturnType));
@@ -150,7 +150,7 @@ public sealed class TypeChecker
var scope = new Scope(); var scope = new Scope();
foreach (var parameter in signature.Parameters) foreach (var parameter in signature.Parameters)
{ {
scope.Declare(new Variable(parameter.Name, parameter.Type)); scope.Declare(new Identifier(parameter.Name, parameter.Type, IdentifierKind.FunctionParameter));
} }
_funcReturnTypes.Push(signature.ReturnType); _funcReturnTypes.Push(signature.ReturnType);
@@ -178,9 +178,14 @@ public sealed class TypeChecker
private StatementNode CheckAssignment(AssignmentSyntax statement) private StatementNode CheckAssignment(AssignmentSyntax statement)
{ {
var expression = CheckExpression(statement.Target); var target = CheckExpression(statement.Target);
var value = CheckExpression(statement.Value, expression.Type); if (target is not LValueExpressionNode targetLValue)
return new AssignmentNode(expression, value); {
throw new TypeCheckerException(Diagnostic.Error("Cannot assign to rvalue").Build());
}
var value = CheckExpression(statement.Value, target.Type);
return new AssignmentNode(targetLValue, value);
} }
private IfNode CheckIf(IfSyntax statement) private IfNode CheckIf(IfSyntax statement)
@@ -252,7 +257,7 @@ public sealed class TypeChecker
throw new TypeCheckerException(Diagnostic.Error($"Unknown type of variable {statement.Name}").Build()); throw new TypeCheckerException(Diagnostic.Error($"Unknown type of variable {statement.Name}").Build());
} }
Scope.Declare(new Variable(statement.Name, type)); Scope.Declare(new Identifier(statement.Name, type, IdentifierKind.Variable));
return new VariableDeclarationNode(statement.Name, assignment, type); return new VariableDeclarationNode(statement.Name, assignment, type);
} }
@@ -306,7 +311,13 @@ public sealed class TypeChecker
private AddressOfNode CheckAddressOf(AddressOfSyntax expression) private AddressOfNode CheckAddressOf(AddressOfSyntax expression)
{ {
var inner = CheckExpression(expression.Expression); var inner = CheckExpression(expression.Expression);
return new AddressOfNode(new PointerTypeNode(inner.Type), inner);
if (inner is not LValueExpressionNode lValueInner)
{
throw new TypeCheckerException(Diagnostic.Error("Cannot take address of rvalue").Build());
}
return new AddressOfNode(new PointerTypeNode(inner.Type), lValueInner);
} }
private ArrayIndexAccessNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression) private ArrayIndexAccessNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression)
@@ -434,10 +445,15 @@ public sealed class TypeChecker
private ExpressionNode CheckIdentifier(IdentifierSyntax expression) private ExpressionNode CheckIdentifier(IdentifierSyntax expression)
{ {
var variable = Scope.Lookup(expression.Name); var identifier = Scope.Lookup(expression.Name);
if (variable != null) if (identifier != null)
{ {
return new VariableIdentNode(variable.Type, variable.Name); return identifier.Kind switch
{
IdentifierKind.Variable => new VariableIdentNode(identifier.Type, identifier.Name),
IdentifierKind.FunctionParameter => new FuncParameterIdentNode(identifier.Type, identifier.Name),
_ => throw new ArgumentOutOfRangeException()
};
} }
var localFuncs = _definitionTable.LookupLocalFunc(expression.Name).ToArray(); var localFuncs = _definitionTable.LookupLocalFunc(expression.Name).ToArray();
@@ -768,13 +784,19 @@ public sealed class TypeChecker
} }
} }
public record Variable(string Name, TypeNode Type); public enum IdentifierKind
{
Variable,
FunctionParameter
}
public record Identifier(string Name, TypeNode Type, IdentifierKind Kind);
public class Scope(Scope? parent = null) public class Scope(Scope? parent = null)
{ {
private readonly List<Variable> _variables = []; private readonly List<Identifier> _variables = [];
public Variable? Lookup(string name) public Identifier? Lookup(string name)
{ {
var variable = _variables.FirstOrDefault(x => x.Name == name); var variable = _variables.FirstOrDefault(x => x.Name == name);
if (variable != null) if (variable != null)
@@ -785,9 +807,9 @@ public class Scope(Scope? parent = null)
return parent?.Lookup(name); return parent?.Lookup(name);
} }
public void Declare(Variable variable) public void Declare(Identifier identifier)
{ {
_variables.Add(variable); _variables.Add(identifier);
} }
public Scope SubScope() public Scope SubScope()