WIP: dev #1

Draft
nub31 wants to merge 103 commits from dev into master
4 changed files with 780 additions and 73 deletions
Showing only changes of commit 3db412a060 - Show all commits

View File

@@ -44,6 +44,17 @@ public sealed class DiagnosticBuilder(DiagnosticSeverity severity, string messag
return this; return this;
} }
public DiagnosticBuilder At(string fileName, TypedNode? node)
{
if (node != null && node.Tokens.Count != 0)
{
// todo(nub31): Calculate length based on last token
At(fileName, node.Tokens[0]);
}
return this;
}
public DiagnosticBuilder WithHelp(string helpMessage) public DiagnosticBuilder WithHelp(string helpMessage)
{ {
help = helpMessage; help = helpMessage;

View File

@@ -2,14 +2,15 @@
namespace Compiler; namespace Compiler;
public sealed class Generator(List<NodeDefinition> nodes) public sealed class Generator(TypedAst ast)
{ {
public static string Emit(List<NodeDefinition> nodes) public static string Emit(TypedAst ast)
{ {
return new Generator(nodes).Emit(); return new Generator(ast).Emit();
} }
private IndentedTextWriter writer = new(); private IndentedTextWriter writer = new();
private Dictionary<NubTypeStruct, string> structTypeNames = new();
private string Emit() private string Emit()
{ {
@@ -27,23 +28,22 @@ public sealed class Generator(List<NodeDefinition> nodes)
"""); """);
foreach (var node in nodes.OfType<NodeDefinitionFunc>()) for (var i = 0; i < ast.StructTypes.Count; i++)
{ {
var parameters = node.Parameters.Select(x => CType(x.Type, x.Name.Ident)); var structType = ast.StructTypes[i];
writer.WriteLine($"{CType(node.ReturnType, node.Name.Ident)}({string.Join(", ", parameters)});"); structTypeNames[structType] = $"s{i}";
} }
writer.WriteLine(); foreach (var structType in ast.StructTypes)
foreach (var node in nodes.OfType<NodeDefinitionStruct>())
{ {
writer.WriteLine($"struct {node.Name.Ident}"); var name = structTypeNames[structType];
writer.WriteLine($"struct {name}");
writer.WriteLine("{"); writer.WriteLine("{");
using (writer.Indent()) using (writer.Indent())
{ {
foreach (var field in node.Fields) foreach (var field in structType.Fields)
{ {
writer.WriteLine($"{CType(field.Type, field.Name.Ident)};"); writer.WriteLine($"{CType(field.Type, field.Name)};");
} }
} }
@@ -52,7 +52,15 @@ public sealed class Generator(List<NodeDefinition> nodes)
writer.WriteLine(); writer.WriteLine();
foreach (var node in nodes.OfType<NodeDefinitionFunc>()) foreach (var node in ast.Functions)
{
var parameters = node.Parameters.Select(x => CType(x.Type, x.Name.Ident));
writer.WriteLine($"{CType(node.ReturnType, node.Name.Ident)}({string.Join(", ", parameters)});");
}
writer.WriteLine();
foreach (var node in ast.Functions)
{ {
var parameters = node.Parameters.Select(x => CType(x.Type, x.Name.Ident)); var parameters = node.Parameters.Select(x => CType(x.Type, x.Name.Ident));
writer.WriteLine($"{CType(node.ReturnType, node.Name.Ident)}({string.Join(", ", parameters)})"); writer.WriteLine($"{CType(node.ReturnType, node.Name.Ident)}({string.Join(", ", parameters)})");
@@ -69,29 +77,29 @@ public sealed class Generator(List<NodeDefinition> nodes)
return writer.ToString(); return writer.ToString();
} }
private void EmitStatement(NodeStatement node) private void EmitStatement(TypedNodeStatement node)
{ {
switch (node) switch (node)
{ {
case NodeStatementBlock statement: case TypedNodeStatementBlock statement:
EmitStatementBlock(statement); EmitStatementBlock(statement);
break; break;
case NodeStatementFuncCall statement: case TypedNodeStatementFuncCall statement:
EmitStatementFuncCall(statement); EmitStatementFuncCall(statement);
break; break;
case NodeStatementReturn statement: case TypedNodeStatementReturn statement:
EmitStatementReturn(statement); EmitStatementReturn(statement);
break; break;
case NodeStatementVariableDeclaration statement: case TypedNodeStatementVariableDeclaration statement:
EmitStatementVariableDeclaration(statement); EmitStatementVariableDeclaration(statement);
break; break;
case NodeStatementAssignment statement: case TypedNodeStatementAssignment statement:
EmitStatementAssignment(statement); EmitStatementAssignment(statement);
break; break;
case NodeStatementIf statement: case TypedNodeStatementIf statement:
EmitStatementIf(statement); EmitStatementIf(statement);
break; break;
case NodeStatementWhile statement: case TypedNodeStatementWhile statement:
EmitStatementWhile(statement); EmitStatementWhile(statement);
break; break;
default: default:
@@ -99,7 +107,7 @@ public sealed class Generator(List<NodeDefinition> nodes)
} }
} }
private void EmitStatementBlock(NodeStatementBlock node) private void EmitStatementBlock(TypedNodeStatementBlock node)
{ {
writer.WriteLine("{"); writer.WriteLine("{");
using (writer.Indent()) using (writer.Indent())
@@ -111,33 +119,33 @@ public sealed class Generator(List<NodeDefinition> nodes)
writer.WriteLine("}"); writer.WriteLine("}");
} }
private void EmitStatementFuncCall(NodeStatementFuncCall node) private void EmitStatementFuncCall(TypedNodeStatementFuncCall node)
{ {
var name = EmitExpression(node.Target); var name = EmitExpression(node.Target);
var parameterValues = node.Parameters.Select(EmitExpression).ToList(); var parameterValues = node.Parameters.Select(EmitExpression).ToList();
writer.WriteLine($"{name}({string.Join(", ", parameterValues)});"); writer.WriteLine($"{name}({string.Join(", ", parameterValues)});");
} }
private void EmitStatementReturn(NodeStatementReturn statement) private void EmitStatementReturn(TypedNodeStatementReturn statement)
{ {
var value = EmitExpression(statement.Value); var value = EmitExpression(statement.Value);
writer.WriteLine($"return {value};"); writer.WriteLine($"return {value};");
} }
private void EmitStatementVariableDeclaration(NodeStatementVariableDeclaration statement) private void EmitStatementVariableDeclaration(TypedNodeStatementVariableDeclaration statement)
{ {
var value = EmitExpression(statement.Value); var value = EmitExpression(statement.Value);
writer.WriteLine($"{CType(statement.Type)} {statement.Name.Ident} = {value};"); writer.WriteLine($"{CType(statement.Type)} {statement.Name.Ident} = {value};");
} }
private void EmitStatementAssignment(NodeStatementAssignment statement) private void EmitStatementAssignment(TypedNodeStatementAssignment statement)
{ {
var target = EmitExpression(statement.Target); var target = EmitExpression(statement.Target);
var value = EmitExpression(statement.Value); var value = EmitExpression(statement.Value);
writer.WriteLine($"{target} = {value};"); writer.WriteLine($"{target} = {value};");
} }
private void EmitStatementIf(NodeStatementIf statement) private void EmitStatementIf(TypedNodeStatementIf statement)
{ {
var condition = EmitExpression(statement.Condition); var condition = EmitExpression(statement.Condition);
writer.WriteLine($"if ({condition})"); writer.WriteLine($"if ({condition})");
@@ -152,7 +160,7 @@ public sealed class Generator(List<NodeDefinition> nodes)
if (statement.ElseBlock != null) if (statement.ElseBlock != null)
{ {
writer.Write("else"); writer.Write("else");
if (statement.ElseBlock is NodeStatementIf) if (statement.ElseBlock is TypedNodeStatementIf)
writer.Write(" "); writer.Write(" ");
else else
writer.WriteLine(); writer.WriteLine();
@@ -167,7 +175,7 @@ public sealed class Generator(List<NodeDefinition> nodes)
} }
} }
private void EmitStatementWhile(NodeStatementWhile statement) private void EmitStatementWhile(TypedNodeStatementWhile statement)
{ {
var condition = EmitExpression(statement.Condition); var condition = EmitExpression(statement.Condition);
writer.WriteLine($"while ({condition})"); writer.WriteLine($"while ({condition})");
@@ -180,61 +188,61 @@ public sealed class Generator(List<NodeDefinition> nodes)
writer.WriteLine("}"); writer.WriteLine("}");
} }
private string EmitExpression(NodeExpression node) private string EmitExpression(TypedNodeExpression node)
{ {
return node switch return node switch
{ {
NodeExpressionBinary expression => EmitExpressionBinary(expression), TypedNodeExpressionBinary expression => EmitExpressionBinary(expression),
NodeExpressionUnary expression => EmitExpressionUnary(expression), TypedNodeExpressionUnary expression => EmitExpressionUnary(expression),
NodeExpressionBoolLiteral expression => expression.Value.Value ? "true" : "false", TypedNodeExpressionBoolLiteral expression => expression.Value.Value ? "true" : "false",
NodeExpressionIntLiteral expression => expression.Value.Value.ToString(), TypedNodeExpressionIntLiteral expression => expression.Value.Value.ToString(),
NodeExpressionStringLiteral expression => $"(struct string){{ \"{expression.Value.Value}\", {expression.Value.Value.Length} }}", TypedNodeExpressionStringLiteral expression => $"(struct string){{ \"{expression.Value.Value}\", {expression.Value.Value.Length} }}",
NodeExpressionStructLiteral expression => EmitExpressionStructLiteral(expression), TypedNodeExpressionStructLiteral expression => EmitExpressionStructLiteral(expression),
NodeExpressionMemberAccess expression => EmitExpressionMemberAccess(expression), TypedNodeExpressionMemberAccess expression => EmitExpressionMemberAccess(expression),
NodeExpressionIdent expression => expression.Value.Ident, TypedNodeExpressionIdent expression => expression.Value.Ident,
_ => throw new ArgumentOutOfRangeException(nameof(node), node, null) _ => throw new ArgumentOutOfRangeException(nameof(node), node, null)
}; };
} }
private string EmitExpressionBinary(NodeExpressionBinary expression) private string EmitExpressionBinary(TypedNodeExpressionBinary expression)
{ {
var left = EmitExpression(expression.Left); var left = EmitExpression(expression.Left);
var right = EmitExpression(expression.Right); var right = EmitExpression(expression.Right);
return expression.Operation switch return expression.Operation switch
{ {
NodeExpressionBinary.Op.Add => $"({left} + {right})", TypedNodeExpressionBinary.Op.Add => $"({left} + {right})",
NodeExpressionBinary.Op.Subtract => $"({left} - {right})", TypedNodeExpressionBinary.Op.Subtract => $"({left} - {right})",
NodeExpressionBinary.Op.Multiply => $"({left} * {right})", TypedNodeExpressionBinary.Op.Multiply => $"({left} * {right})",
NodeExpressionBinary.Op.Divide => $"({left} / {right})", TypedNodeExpressionBinary.Op.Divide => $"({left} / {right})",
NodeExpressionBinary.Op.Modulo => $"({left} % {right})", TypedNodeExpressionBinary.Op.Modulo => $"({left} % {right})",
NodeExpressionBinary.Op.Equal => $"({left} == {right})", TypedNodeExpressionBinary.Op.Equal => $"({left} == {right})",
NodeExpressionBinary.Op.NotEqual => $"({left} != {right})", TypedNodeExpressionBinary.Op.NotEqual => $"({left} != {right})",
NodeExpressionBinary.Op.LessThan => $"({left} < {right})", TypedNodeExpressionBinary.Op.LessThan => $"({left} < {right})",
NodeExpressionBinary.Op.LessThanOrEqual => $"({left} <= {right})", TypedNodeExpressionBinary.Op.LessThanOrEqual => $"({left} <= {right})",
NodeExpressionBinary.Op.GreaterThan => $"({left} > {right})", TypedNodeExpressionBinary.Op.GreaterThan => $"({left} > {right})",
NodeExpressionBinary.Op.GreaterThanOrEqual => $"({left} >= {right})", TypedNodeExpressionBinary.Op.GreaterThanOrEqual => $"({left} >= {right})",
NodeExpressionBinary.Op.LeftShift => $"({left} << {right})", TypedNodeExpressionBinary.Op.LeftShift => $"({left} << {right})",
NodeExpressionBinary.Op.RightShift => $"({left} >> {right})", TypedNodeExpressionBinary.Op.RightShift => $"({left} >> {right})",
NodeExpressionBinary.Op.LogicalAnd => $"({left} && {right})", TypedNodeExpressionBinary.Op.LogicalAnd => $"({left} && {right})",
NodeExpressionBinary.Op.LogicalOr => $"({left} || {right})", TypedNodeExpressionBinary.Op.LogicalOr => $"({left} || {right})",
_ => throw new ArgumentOutOfRangeException() _ => throw new ArgumentOutOfRangeException()
}; };
} }
private string EmitExpressionUnary(NodeExpressionUnary expression) private string EmitExpressionUnary(TypedNodeExpressionUnary expression)
{ {
var target = EmitExpression(expression.Target); var target = EmitExpression(expression.Target);
return expression.Operation switch return expression.Operation switch
{ {
NodeExpressionUnary.Op.Negate => $"(-{target})", TypedNodeExpressionUnary.Op.Negate => $"(-{target})",
NodeExpressionUnary.Op.Invert => $"(!{target})", TypedNodeExpressionUnary.Op.Invert => $"(!{target})",
_ => throw new ArgumentOutOfRangeException() _ => throw new ArgumentOutOfRangeException()
}; };
} }
private string EmitExpressionStructLiteral(NodeExpressionStructLiteral expression) private string EmitExpressionStructLiteral(TypedNodeExpressionStructLiteral expression)
{ {
var initializerValues = new Dictionary<string, string>(); var initializerValues = new Dictionary<string, string>();
@@ -246,27 +254,27 @@ public sealed class Generator(List<NodeDefinition> nodes)
var initializerStrings = initializerValues.Select(x => $".{x.Key} = {x.Value}"); var initializerStrings = initializerValues.Select(x => $".{x.Key} = {x.Value}");
return $"(struct {expression.Name.Ident}){{ {string.Join(", ", initializerStrings)} }}"; return $"(struct {structTypeNames[(NubTypeStruct)expression.Type]}){{ {string.Join(", ", initializerStrings)} }}";
} }
private string EmitExpressionMemberAccess(NodeExpressionMemberAccess expression) private string EmitExpressionMemberAccess(TypedNodeExpressionMemberAccess expression)
{ {
var target = EmitExpression(expression.Target); var target = EmitExpression(expression.Target);
return $"{target}.{expression.Name.Ident}"; return $"{target}.{expression.Name.Ident}";
} }
private static string CType(NodeType node, string? varName = null) private string CType(NubType node, string? varName = null)
{ {
return node switch return node switch
{ {
NodeTypeVoid => "void" + (varName != null ? $" {varName}" : ""), NubTypeVoid => "void" + (varName != null ? $" {varName}" : ""),
NodeTypeBool => "bool" + (varName != null ? $" {varName}" : ""), NubTypeBool => "bool" + (varName != null ? $" {varName}" : ""),
NodeTypeCustom type => $"struct {type.Name.Ident}" + (varName != null ? $" {varName}" : ""), NubTypeStruct type => $"struct {structTypeNames[type]}" + (varName != null ? $" {varName}" : ""),
NodeTypeSInt type => $"int{type.Width}_t" + (varName != null ? $" {varName}" : ""), NubTypeSInt type => $"int{type.Width}_t" + (varName != null ? $" {varName}" : ""),
NodeTypeUInt type => $"uint{type.Width}_t" + (varName != null ? $" {varName}" : ""), NubTypeUInt type => $"uint{type.Width}_t" + (varName != null ? $" {varName}" : ""),
NodeTypePointer type => CType(type.To) + (varName != null ? $" *{varName}" : "*"), NubTypePointer type => CType(type.To) + (varName != null ? $" *{varName}" : "*"),
NodeTypeString => "struct string" + (varName != null ? $" {varName}" : ""), NubTypeString => "struct string" + (varName != null ? $" {varName}" : ""),
NodeTypeFunc type => $"{CType(type.ReturnType)} (*{varName})({string.Join(", ", type.Parameters.Select(p => CType(p)))})", NubTypeFunc type => $"{CType(type.ReturnType)} (*{varName})({string.Join(", ", type.Parameters.Select(p => CType(p)))})",
_ => throw new ArgumentOutOfRangeException(nameof(node), node, null) _ => throw new ArgumentOutOfRangeException(nameof(node), node, null)
}; };
} }

View File

@@ -1,8 +1,10 @@
using Compiler; using Compiler;
var file = File.ReadAllText("test.nub"); const string fileName = "test.nub";
var tokens = Tokenizer.Tokenize("test.nub", file, out var tokenizerDiagnostics); var file = File.ReadAllText(fileName);
var tokens = Tokenizer.Tokenize(fileName, file, out var tokenizerDiagnostics);
foreach (var diagnostic in tokenizerDiagnostics) foreach (var diagnostic in tokenizerDiagnostics)
{ {
@@ -14,7 +16,7 @@ if (tokenizerDiagnostics.Any(x => x.Severity == DiagnosticSeverity.Error))
return 1; return 1;
} }
var nodes = Parser.Parse("test.nub", tokens, out var parserDiagnostics); var nodes = Parser.Parse(fileName, tokens, out var parserDiagnostics);
foreach (var diagnostic in parserDiagnostics) foreach (var diagnostic in parserDiagnostics)
{ {
@@ -26,7 +28,19 @@ if (parserDiagnostics.Any(x => x.Severity == DiagnosticSeverity.Error))
return 1; return 1;
} }
var output = Generator.Emit(nodes); var typedNodes = TypeChecker.Check(fileName, nodes, out var typeCheckerDiagnostics);
foreach (var diagnostic in typeCheckerDiagnostics)
{
DiagnosticFormatter.Print(diagnostic, Console.Error);
}
if (typeCheckerDiagnostics.Any(x => x.Severity == DiagnosticSeverity.Error))
{
return 1;
}
var output = Generator.Emit(typedNodes);
File.WriteAllText("C:/Users/oliste/repos/nub-lang/compiler/Compiler/out.c", output); File.WriteAllText("C:/Users/oliste/repos/nub-lang/compiler/Compiler/out.c", output);

View File

@@ -0,0 +1,674 @@
namespace Compiler;
public sealed class TypeChecker(string fileName, List<NodeDefinition> definitions)
{
public static TypedAst Check(string fileName, List<NodeDefinition> nodes, out List<Diagnostic> diagnostics)
{
return new TypeChecker(fileName, nodes).Check(out diagnostics);
}
private Scope scope = new(null);
private Dictionary<string, NubTypeStruct> structTypes = new();
private TypedAst Check(out List<Diagnostic> diagnostics)
{
var functions = new List<TypedNodeDefinitionFunc>();
diagnostics = [];
// todo(nub31): Types must be resolved better to prevent circular dependencies and independent ordering
foreach (var structDef in definitions.OfType<NodeDefinitionStruct>())
{
var fields = structDef.Fields.Select(x => new NubTypeStruct.Field(x.Name.Ident, CheckType(x.Type))).ToList();
structTypes.Add(structDef.Name.Ident, new NubTypeStruct(fields));
}
foreach (var funcDef in definitions.OfType<NodeDefinitionFunc>())
{
var type = new NubTypeFunc(funcDef.Parameters.Select(x => CheckType(x.Type)).ToList(), CheckType(funcDef.ReturnType));
scope.DeclareIdentifier(funcDef.Name.Ident, type);
}
foreach (var funcDef in definitions.OfType<NodeDefinitionFunc>())
{
try
{
functions.Add(CheckDefinitionFunc(funcDef));
}
catch (CompileException e)
{
diagnostics.Add(e.Diagnostic);
}
}
return new TypedAst(functions, structTypes.Values.ToList());
}
private TypedNodeDefinitionFunc CheckDefinitionFunc(NodeDefinitionFunc definition)
{
return new TypedNodeDefinitionFunc(definition.Tokens, definition.Name, definition.Parameters.Select(CheckDefinitionFuncParameter).ToList(), CheckStatement(definition.Body), CheckType(definition.ReturnType));
}
private TypedNodeDefinitionFunc.Param CheckDefinitionFuncParameter(NodeDefinitionFunc.Param node)
{
return new TypedNodeDefinitionFunc.Param(node.Tokens, node.Name, CheckType(node.Type));
}
private TypedNodeStatement CheckStatement(NodeStatement node)
{
return node switch
{
NodeStatementAssignment statement => CheckStatementAssignment(statement),
NodeStatementBlock statement => CheckStatementBlock(statement),
NodeStatementFuncCall statement => CheckStatementFuncCall(statement),
NodeStatementIf statement => CheckStatementIf(statement),
NodeStatementReturn statement => CheckStatementReturn(statement),
NodeStatementVariableDeclaration statement => CheckStatementVariableDeclaration(statement),
NodeStatementWhile statement => CheckStatementWhile(statement),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private TypedNodeStatementAssignment CheckStatementAssignment(NodeStatementAssignment statement)
{
return new TypedNodeStatementAssignment(statement.Tokens, CheckExpression(statement.Target), CheckExpression(statement.Value));
}
private TypedNodeStatementBlock CheckStatementBlock(NodeStatementBlock statement)
{
return new TypedNodeStatementBlock(statement.Tokens, statement.Statements.Select(CheckStatement).ToList());
}
private TypedNodeStatementFuncCall CheckStatementFuncCall(NodeStatementFuncCall statement)
{
return new TypedNodeStatementFuncCall(statement.Tokens, CheckExpression(statement.Target), statement.Parameters.Select(CheckExpression).ToList());
}
private TypedNodeStatementIf CheckStatementIf(NodeStatementIf statement)
{
return new TypedNodeStatementIf(statement.Tokens, CheckExpression(statement.Condition), CheckStatement(statement.ThenBlock), statement.ElseBlock == null ? null : CheckStatement(statement.ElseBlock));
}
private TypedNodeStatementReturn CheckStatementReturn(NodeStatementReturn statement)
{
return new TypedNodeStatementReturn(statement.Tokens, CheckExpression(statement.Value));
}
private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration(NodeStatementVariableDeclaration statement)
{
var type = CheckType(statement.Type);
var value = CheckExpression(statement.Value);
if (type != value.Type)
throw new CompileException(Diagnostic.Error("Type of variable does match type of assigned value").At(fileName, value).Build());
scope.DeclareIdentifier(statement.Name.Ident, type);
return new TypedNodeStatementVariableDeclaration(statement.Tokens, statement.Name, type, value);
}
private TypedNodeStatementWhile CheckStatementWhile(NodeStatementWhile statement)
{
return new TypedNodeStatementWhile(statement.Tokens, CheckExpression(statement.Condition), CheckStatement(statement.Block));
}
private TypedNodeExpression CheckExpression(NodeExpression node)
{
return node switch
{
NodeExpressionBinary expression => CheckExpressionBinary(expression),
NodeExpressionUnary expression => CheckExpressionUnary(expression),
NodeExpressionBoolLiteral expression => CheckExpressionBoolLiteral(expression),
NodeExpressionIdent expression => CheckExpressionIdent(expression),
NodeExpressionIntLiteral expression => CheckExpressionIntLiteral(expression),
NodeExpressionMemberAccess expression => CheckExpressionMemberAccess(expression),
NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression),
NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private TypedNodeExpressionBinary CheckExpressionBinary(NodeExpressionBinary expression)
{
var left = CheckExpression(expression.Left);
var right = CheckExpression(expression.Right);
NubType type;
switch (expression.Operation)
{
case NodeExpressionBinary.Op.Add:
case NodeExpressionBinary.Op.Subtract:
case NodeExpressionBinary.Op.Multiply:
case NodeExpressionBinary.Op.Divide:
case NodeExpressionBinary.Op.Modulo:
{
if (left.Type is not NubTypeSInt and not NubTypeUInt)
throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side arithmetic operation: {left.Type}").At(fileName, left).Build());
if (right.Type is not NubTypeSInt and not NubTypeUInt)
throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side arithmetic operation: {right.Type}").At(fileName, right).Build());
type = left.Type;
break;
}
case NodeExpressionBinary.Op.LeftShift:
case NodeExpressionBinary.Op.RightShift:
{
if (left.Type is not NubTypeUInt)
throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of left/right shift operation: {left.Type}").At(fileName, left).Build());
if (right.Type is not NubTypeUInt)
throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of left/right shift operation: {right.Type}").At(fileName, right).Build());
type = left.Type;
break;
}
case NodeExpressionBinary.Op.Equal:
case NodeExpressionBinary.Op.NotEqual:
case NodeExpressionBinary.Op.LessThan:
case NodeExpressionBinary.Op.LessThanOrEqual:
case NodeExpressionBinary.Op.GreaterThan:
case NodeExpressionBinary.Op.GreaterThanOrEqual:
{
if (left.Type is not NubTypeSInt and not NubTypeUInt)
throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of comparison: {left.Type}").At(fileName, left).Build());
if (right.Type is not NubTypeSInt and not NubTypeUInt)
throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of comparison: {right.Type}").At(fileName, right).Build());
type = new NubTypeBool();
break;
}
case NodeExpressionBinary.Op.LogicalAnd:
case NodeExpressionBinary.Op.LogicalOr:
{
if (left.Type is not NubTypeBool)
throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of logical operation: {left.Type}").At(fileName, left).Build());
if (right.Type is not NubTypeBool)
throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of logical operation: {right.Type}").At(fileName, right).Build());
type = new NubTypeBool();
break;
}
default:
throw new ArgumentOutOfRangeException();
}
return new TypedNodeExpressionBinary(expression.Tokens, type, left, CheckExpressionBinaryOperation(expression.Operation), right);
}
private static TypedNodeExpressionBinary.Op CheckExpressionBinaryOperation(NodeExpressionBinary.Op op)
{
return op switch
{
NodeExpressionBinary.Op.Add => TypedNodeExpressionBinary.Op.Add,
NodeExpressionBinary.Op.Subtract => TypedNodeExpressionBinary.Op.Subtract,
NodeExpressionBinary.Op.Multiply => TypedNodeExpressionBinary.Op.Multiply,
NodeExpressionBinary.Op.Divide => TypedNodeExpressionBinary.Op.Divide,
NodeExpressionBinary.Op.Modulo => TypedNodeExpressionBinary.Op.Modulo,
NodeExpressionBinary.Op.Equal => TypedNodeExpressionBinary.Op.Equal,
NodeExpressionBinary.Op.NotEqual => TypedNodeExpressionBinary.Op.NotEqual,
NodeExpressionBinary.Op.LessThan => TypedNodeExpressionBinary.Op.LessThan,
NodeExpressionBinary.Op.LessThanOrEqual => TypedNodeExpressionBinary.Op.LessThanOrEqual,
NodeExpressionBinary.Op.GreaterThan => TypedNodeExpressionBinary.Op.GreaterThan,
NodeExpressionBinary.Op.GreaterThanOrEqual => TypedNodeExpressionBinary.Op.GreaterThanOrEqual,
NodeExpressionBinary.Op.LeftShift => TypedNodeExpressionBinary.Op.LeftShift,
NodeExpressionBinary.Op.RightShift => TypedNodeExpressionBinary.Op.RightShift,
NodeExpressionBinary.Op.LogicalAnd => TypedNodeExpressionBinary.Op.LogicalAnd,
NodeExpressionBinary.Op.LogicalOr => TypedNodeExpressionBinary.Op.LogicalOr,
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
private TypedNodeExpressionUnary CheckExpressionUnary(NodeExpressionUnary expression)
{
var target = CheckExpression(expression.Target);
NubType type;
switch (expression.Operation)
{
case NodeExpressionUnary.Op.Negate:
{
if (target.Type is not NubTypeSInt and not NubTypeUInt)
throw new CompileException(Diagnostic.Error($"Unsupported type for negation: {target.Type}").At(fileName, target).Build());
type = target.Type;
break;
}
case NodeExpressionUnary.Op.Invert:
{
if (target.Type is not NubTypeBool)
throw new CompileException(Diagnostic.Error($"Unsupported type for inversion: {target.Type}").At(fileName, target).Build());
type = new NubTypeBool();
break;
}
default:
throw new ArgumentOutOfRangeException();
}
return new TypedNodeExpressionUnary(expression.Tokens, type, target, CheckExpressionUnaryOperation(expression.Operation));
}
private static TypedNodeExpressionUnary.Op CheckExpressionUnaryOperation(NodeExpressionUnary.Op op)
{
return op switch
{
NodeExpressionUnary.Op.Negate => TypedNodeExpressionUnary.Op.Negate,
NodeExpressionUnary.Op.Invert => TypedNodeExpressionUnary.Op.Invert,
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral(NodeExpressionBoolLiteral expression)
{
return new TypedNodeExpressionBoolLiteral(expression.Tokens, new NubTypeBool(), expression.Value);
}
private TypedNodeExpressionIdent CheckExpressionIdent(NodeExpressionIdent expression)
{
var type = scope.GetIdentifierType(expression.Value.Ident);
if (type == null)
throw new CompileException(Diagnostic.Error($"Identifier '{expression.Value.Ident}' is not declared").At(fileName, expression.Value).Build());
return new TypedNodeExpressionIdent(expression.Tokens, type, expression.Value);
}
private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral(NodeExpressionIntLiteral expression)
{
return new TypedNodeExpressionIntLiteral(expression.Tokens, new NubTypeSInt(32), expression.Value);
}
private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess(NodeExpressionMemberAccess expression)
{
var target = CheckExpression(expression.Target);
if (target.Type is not NubTypeStruct structType)
throw new CompileException(Diagnostic.Error($"Cannot access member of non-struct type {target.Type}").At(fileName, target).Build());
var field = structType.Fields.FirstOrDefault(x => x.Name == expression.Name.Ident);
if (field == null)
throw new CompileException(Diagnostic.Error($"Struct {target.Type} does not have a field matching the name '{expression.Name.Ident}'").At(fileName, target).Build());
return new TypedNodeExpressionMemberAccess(expression.Tokens, field.Type, target, expression.Name);
}
private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression)
{
return new TypedNodeExpressionStringLiteral(expression.Tokens, new NubTypeString(), expression.Value);
}
private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression)
{
var type = structTypes.GetValueOrDefault(expression.Name.Ident);
if (type == null)
throw new CompileException(Diagnostic.Error($"Undeclared struct '{expression.Name.Ident}'").At(fileName, expression.Name).Build());
var initializers = new List<TypedNodeExpressionStructLiteral.Initializer>();
foreach (var initializer in expression.Initializers)
{
var field = type.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident);
if (field == null)
throw new CompileException(Diagnostic.Error($"Field '{initializer.Name.Ident}' does not exist on struct '{expression.Name.Ident}'").At(fileName, initializer.Name).Build());
var value = CheckExpression(initializer.Value);
if (value.Type != field.Type)
throw new CompileException(Diagnostic.Error($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})").At(fileName, initializer.Name).Build());
initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value));
}
return new TypedNodeExpressionStructLiteral(expression.Tokens, type, initializers);
}
private NubType CheckType(NodeType node)
{
return node switch
{
NodeTypeBool type => new NubTypeBool(),
NodeTypeCustom type => CheckStructType(type),
NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)),
NodeTypePointer type => new NubTypePointer(CheckType(type.To)),
NodeTypeSInt type => new NubTypeSInt(type.Width),
NodeTypeUInt type => new NubTypeUInt(type.Width),
NodeTypeString type => new NubTypeString(),
NodeTypeVoid type => new NubTypeVoid(),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private NubTypeStruct CheckStructType(NodeTypeCustom type)
{
var structType = structTypes.GetValueOrDefault(type.Name.Ident);
if (structType == null)
throw new CompileException(Diagnostic.Error($"Unknown custom type: {type}").At(fileName, type).Build());
return structType;
}
private class Scope(Scope? parent)
{
private Dictionary<string, NubType> identifiers = new();
public void DeclareIdentifier(string name, NubType type)
{
identifiers.Add(name, type);
}
public NubType? GetIdentifierType(string name)
{
return identifiers.TryGetValue(name, out var type)
? type
: parent?.GetIdentifierType(name);
}
}
}
public sealed class TypedAst(List<TypedNodeDefinitionFunc> functions, List<NubTypeStruct> structTypes)
{
public List<TypedNodeDefinitionFunc> Functions = functions;
public List<NubTypeStruct> StructTypes = structTypes;
}
public abstract class TypedNode(List<Token> tokens)
{
public readonly List<Token> Tokens = tokens;
}
public abstract class TypedNodeDefinition(List<Token> tokens) : TypedNode(tokens);
public sealed class TypedNodeDefinitionFunc(List<Token> tokens, TokenIdent name, List<TypedNodeDefinitionFunc.Param> parameters, TypedNodeStatement body, NubType returnType) : TypedNodeDefinition(tokens)
{
public readonly TokenIdent Name = name;
public readonly List<Param> Parameters = parameters;
public readonly TypedNodeStatement Body = body;
public readonly NubType ReturnType = returnType;
public sealed class Param(List<Token> tokens, TokenIdent name, NubType type) : TypedNode(tokens)
{
public readonly TokenIdent Name = name;
public readonly NubType Type = type;
}
}
public sealed class TypedNodeDefinitionStruct(List<Token> tokens, TokenIdent name, List<TypedNodeDefinitionStruct.Field> fields) : TypedNodeDefinition(tokens)
{
public readonly TokenIdent Name = name;
public readonly List<Field> Fields = fields;
public sealed class Field(List<Token> tokens, TokenIdent name, NubType type) : TypedNode(tokens)
{
public readonly TokenIdent Name = name;
public readonly NubType Type = type;
}
}
public abstract class TypedNodeStatement(List<Token> tokens) : TypedNode(tokens);
public sealed class TypedNodeStatementBlock(List<Token> tokens, List<TypedNodeStatement> statements) : TypedNodeStatement(tokens)
{
public readonly List<TypedNodeStatement> Statements = statements;
}
public sealed class TypedNodeStatementFuncCall(List<Token> tokens, TypedNodeExpression target, List<TypedNodeExpression> parameters) : TypedNodeStatement(tokens)
{
public readonly TypedNodeExpression Target = target;
public readonly List<TypedNodeExpression> Parameters = parameters;
}
public sealed class TypedNodeStatementReturn(List<Token> tokens, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
public readonly TypedNodeExpression Value = value;
}
public sealed class TypedNodeStatementVariableDeclaration(List<Token> tokens, TokenIdent name, NubType type, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
public readonly TokenIdent Name = name;
public readonly NubType Type = type;
public readonly TypedNodeExpression Value = value;
}
public sealed class TypedNodeStatementAssignment(List<Token> tokens, TypedNodeExpression target, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
public readonly TypedNodeExpression Target = target;
public readonly TypedNodeExpression Value = value;
}
public sealed class TypedNodeStatementIf(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement thenBlock, TypedNodeStatement? elseBlock) : TypedNodeStatement(tokens)
{
public readonly TypedNodeExpression Condition = condition;
public readonly TypedNodeStatement ThenBlock = thenBlock;
public readonly TypedNodeStatement? ElseBlock = elseBlock;
}
public sealed class TypedNodeStatementWhile(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement block) : TypedNodeStatement(tokens)
{
public readonly TypedNodeExpression Condition = condition;
public readonly TypedNodeStatement Block = block;
}
public abstract class TypedNodeExpression(List<Token> tokens, NubType type) : TypedNode(tokens)
{
public readonly NubType Type = type;
}
public sealed class TypedNodeExpressionIntLiteral(List<Token> tokens, NubType type, TokenIntLiteral value) : TypedNodeExpression(tokens, type)
{
public readonly TokenIntLiteral Value = value;
}
public sealed class TypedNodeExpressionStringLiteral(List<Token> tokens, NubType type, TokenStringLiteral value) : TypedNodeExpression(tokens, type)
{
public readonly TokenStringLiteral Value = value;
}
public sealed class TypedNodeExpressionBoolLiteral(List<Token> tokens, NubType type, TokenBoolLiteral value) : TypedNodeExpression(tokens, type)
{
public readonly TokenBoolLiteral Value = value;
}
public sealed class TypedNodeExpressionStructLiteral(List<Token> tokens, NubType type, List<TypedNodeExpressionStructLiteral.Initializer> initializers) : TypedNodeExpression(tokens, type)
{
public readonly List<Initializer> Initializers = initializers;
public sealed class Initializer(List<Token> tokens, TokenIdent name, TypedNodeExpression value) : Node(tokens)
{
public readonly TokenIdent Name = name;
public readonly TypedNodeExpression Value = value;
}
}
public sealed class TypedNodeExpressionMemberAccess(List<Token> tokens, NubType type, TypedNodeExpression target, TokenIdent name) : TypedNodeExpression(tokens, type)
{
public readonly TypedNodeExpression Target = target;
public readonly TokenIdent Name = name;
}
public sealed class TypedNodeExpressionIdent(List<Token> tokens, NubType type, TokenIdent value) : TypedNodeExpression(tokens, type)
{
public readonly TokenIdent Value = value;
}
public sealed class TypedNodeExpressionBinary(List<Token> tokens, NubType type, TypedNodeExpression left, TypedNodeExpressionBinary.Op operation, TypedNodeExpression right) : TypedNodeExpression(tokens, type)
{
public readonly TypedNodeExpression Left = left;
public readonly Op Operation = operation;
public readonly TypedNodeExpression Right = right;
public enum Op
{
Add,
Subtract,
Multiply,
Divide,
Modulo,
Equal,
NotEqual,
LessThan,
LessThanOrEqual,
GreaterThan,
GreaterThanOrEqual,
LeftShift,
RightShift,
// BitwiseAnd,
// BitwiseXor,
// BitwiseOr,
LogicalAnd,
LogicalOr,
}
}
public sealed class TypedNodeExpressionUnary(List<Token> tokens, NubType type, TypedNodeExpression target, TypedNodeExpressionUnary.Op op) : TypedNodeExpression(tokens, type)
{
public TypedNodeExpression Target { get; } = target;
public Op Operation { get; } = op;
public enum Op
{
Negate,
Invert,
}
}
public abstract class NubType : IEquatable<NubType>
{
public abstract override string ToString();
public abstract bool Equals(NubType? other);
public override bool Equals(object? obj)
{
if (obj is NubType otherNubType)
{
return Equals(otherNubType);
}
return false;
}
public abstract override int GetHashCode();
public static bool operator ==(NubType? left, NubType? right) => Equals(left, right);
public static bool operator !=(NubType? left, NubType? right) => !Equals(left, right);
}
public sealed class NubTypeVoid : NubType
{
public override string ToString() => "void";
public override bool Equals(NubType? other) => other is NubTypeVoid;
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeVoid));
}
public sealed class NubTypeUInt(int width) : NubType
{
public readonly int Width = width;
public override string ToString() => $"u{Width}";
public override bool Equals(NubType? other) => other is NubTypeUInt otherUInt && Width == otherUInt.Width;
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeUInt), Width);
}
public sealed class NubTypeSInt(int width) : NubType
{
public readonly int Width = width;
public override string ToString() => $"i{Width}";
public override bool Equals(NubType? other) => other is NubTypeSInt otherUInt && Width == otherUInt.Width;
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeSInt), Width);
}
public sealed class NubTypeBool : NubType
{
public override string ToString() => "bool";
public override bool Equals(NubType? other) => other is NubTypeBool;
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeBool));
}
public sealed class NubTypeString : NubType
{
public override string ToString() => "string";
public override bool Equals(NubType? other) => other is NubTypeString;
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeString));
}
public sealed class NubTypeStruct(List<NubTypeStruct.Field> fields) : NubType
{
public readonly List<Field> Fields = fields;
public override string ToString() => $"struct {{ {string.Join(' ', Fields.Select(x => $"{x.Name}: {x.Type}"))} }}";
public override bool Equals(NubType? other)
{
if (other is not NubTypeStruct structType)
return false;
if (Fields.Count != structType.Fields.Count)
return false;
for (var i = 0; i < Fields.Count; i++)
{
if (Fields[i].Name != structType.Fields[i].Name)
return false;
if (Fields[i].Type != structType.Fields[i].Type)
return false;
}
return true;
}
public override int GetHashCode()
{
var hash = new HashCode();
hash.Add(typeof(NubTypeStruct));
foreach (var field in Fields)
{
hash.Add(field.Name);
hash.Add(field.Type);
}
return hash.ToHashCode();
}
public sealed class Field(string name, NubType type)
{
public readonly string Name = name;
public readonly NubType Type = type;
}
}
public sealed class NubTypePointer(NubType to) : NubType
{
public readonly NubType To = to;
public override string ToString() => $"^{To}";
public override bool Equals(NubType? other) => other is NubTypePointer pointer && To == pointer.To;
public override int GetHashCode() => HashCode.Combine(typeof(NubTypePointer));
}
public sealed class NubTypeFunc(List<NubType> parameters, NubType returnType) : NubType
{
public readonly List<NubType> Parameters = parameters;
public readonly NubType ReturnType = returnType;
public override string ToString() => $"func({string.Join(' ', Parameters)}): {ReturnType}";
public override bool Equals(NubType? other) => other is NubTypeFunc func && ReturnType.Equals(func.ReturnType) && Parameters.SequenceEqual(func.Parameters);
public override int GetHashCode()
{
var hash = new HashCode();
hash.Add(typeof(NubTypeFunc));
hash.Add(ReturnType);
foreach (var param in Parameters)
hash.Add(param);
return hash.ToHashCode();
}
}