Big type inference improvements

This commit is contained in:
nub31
2025-10-22 12:55:31 +02:00
parent e2da6cccff
commit 93cef598e8
8 changed files with 205 additions and 163 deletions

View File

@@ -7,10 +7,7 @@ using NubLang.Syntax;
var diagnostics = new List<Diagnostic>(); var diagnostics = new List<Diagnostic>();
var syntaxTrees = new List<SyntaxTree>(); var syntaxTrees = new List<SyntaxTree>();
var nubFiles = args.Where(x => Path.GetExtension(x) == ".nub").ToArray(); foreach (var file in args)
var objectFileArgs = args.Where(x => Path.GetExtension(x) is ".o" or ".a").ToArray();
foreach (var file in nubFiles)
{ {
var tokenizer = new Tokenizer(file, File.ReadAllText(file)); var tokenizer = new Tokenizer(file, File.ReadAllText(file));
tokenizer.Tokenize(); tokenizer.Tokenize();
@@ -26,7 +23,7 @@ foreach (var file in nubFiles)
var modules = Module.Collect(syntaxTrees); var modules = Module.Collect(syntaxTrees);
var compilationUnits = new List<CompilationUnit>(); var compilationUnits = new List<CompilationUnit>();
for (var i = 0; i < nubFiles.Length; i++) for (var i = 0; i < args.Length; i++)
{ {
var typeChecker = new TypeChecker(syntaxTrees[i], modules); var typeChecker = new TypeChecker(syntaxTrees[i], modules);
var compilationUnit = typeChecker.Check(); var compilationUnit = typeChecker.Check();
@@ -49,9 +46,9 @@ var cPaths = new List<string>();
Directory.CreateDirectory(".build"); Directory.CreateDirectory(".build");
for (var i = 0; i < nubFiles.Length; i++) for (var i = 0; i < args.Length; i++)
{ {
var file = nubFiles[i]; var file = args[i];
var compilationUnit = compilationUnits[i]; var compilationUnit = compilationUnits[i];
var generator = new Generator(compilationUnit); var generator = new Generator(compilationUnit);

View File

@@ -80,17 +80,29 @@ public abstract record LValueExpressionNode(List<Token> Tokens, NubType Type) :
public abstract record RValueExpressionNode(List<Token> Tokens, NubType Type) : ExpressionNode(Tokens, Type); public abstract record RValueExpressionNode(List<Token> Tokens, NubType Type) : ExpressionNode(Tokens, Type);
public record StringLiteralNode(List<Token> Tokens, NubType Type, string Value) : RValueExpressionNode(Tokens, Type); public record StringLiteralNode(List<Token> Tokens, string Value) : RValueExpressionNode(Tokens, new NubStringType());
public record CStringLiteralNode(List<Token> Tokens, NubType Type, string Value) : RValueExpressionNode(Tokens, Type); public record CStringLiteralNode(List<Token> Tokens, string Value) : RValueExpressionNode(Tokens, new NubCStringType());
public record IntLiteralNode(List<Token> Tokens, NubType Type, long Value) : RValueExpressionNode(Tokens, Type); public record I8LiteralNode(List<Token> Tokens, sbyte Value) : RValueExpressionNode(Tokens, new NubIntType(true, 8));
public record UIntLiteralNode(List<Token> Tokens, NubType Type, ulong Value) : RValueExpressionNode(Tokens, Type); public record I16LiteralNode(List<Token> Tokens, short Value) : RValueExpressionNode(Tokens, new NubIntType(true, 16));
public record Float32LiteralNode(List<Token> Tokens, NubType Type, float Value) : RValueExpressionNode(Tokens, Type); public record I32LiteralNode(List<Token> Tokens, int Value) : RValueExpressionNode(Tokens, new NubIntType(true, 32));
public record Float64LiteralNode(List<Token> Tokens, NubType Type, double Value) : RValueExpressionNode(Tokens, Type); public record I64LiteralNode(List<Token> Tokens, long Value) : RValueExpressionNode(Tokens, new NubIntType(true, 64));
public record U8LiteralNode(List<Token> Tokens, byte Value) : RValueExpressionNode(Tokens, new NubIntType(false, 8));
public record U16LiteralNode(List<Token> Tokens, ushort Value) : RValueExpressionNode(Tokens, new NubIntType(false, 16));
public record U32LiteralNode(List<Token> Tokens, uint Value) : RValueExpressionNode(Tokens, new NubIntType(false, 32));
public record U64LiteralNode(List<Token> Tokens, ulong Value) : RValueExpressionNode(Tokens, new NubIntType(false, 64));
public record Float32LiteralNode(List<Token> Tokens, float Value) : RValueExpressionNode(Tokens, new NubFloatType(32));
public record Float64LiteralNode(List<Token> Tokens, double Value) : RValueExpressionNode(Tokens, new NubFloatType(64));
public record BoolLiteralNode(List<Token> Tokens, NubType Type, bool Value) : RValueExpressionNode(Tokens, Type); public record BoolLiteralNode(List<Token> Tokens, NubType Type, bool Value) : RValueExpressionNode(Tokens, Type);
@@ -114,13 +126,13 @@ public record AddressOfNode(List<Token> Tokens, NubType Type, LValueExpressionNo
public record StructFieldAccessNode(List<Token> Tokens, NubType Type, ExpressionNode Target, string Field) : LValueExpressionNode(Tokens, Type); public record StructFieldAccessNode(List<Token> Tokens, NubType Type, ExpressionNode Target, string Field) : LValueExpressionNode(Tokens, Type);
public record StructInitializerNode(List<Token> Tokens, NubStructType StructType, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(Tokens, StructType); public record StructInitializerNode(List<Token> Tokens, NubType Type, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(Tokens, Type);
public record DereferenceNode(List<Token> Tokens, NubType Type, ExpressionNode Target) : LValueExpressionNode(Tokens, Type); public record DereferenceNode(List<Token> Tokens, NubType Type, ExpressionNode Target) : LValueExpressionNode(Tokens, Type);
public record ConvertIntNode(List<Token> Tokens, NubType Type, ExpressionNode Value, NubIntType ValueType, NubIntType TargetType) : RValueExpressionNode(Tokens, Type); public record ConvertIntNode(List<Token> Tokens, ExpressionNode Value, int StartWidth, int TargetWidth, bool StartSignedness, bool TargetSignedness) : RValueExpressionNode(Tokens, new NubIntType(TargetSignedness, TargetWidth));
public record ConvertFloatNode(List<Token> Tokens, NubType Type, ExpressionNode Value, NubFloatType ValueType, NubFloatType TargetType) : RValueExpressionNode(Tokens, Type); public record ConvertFloatNode(List<Token> Tokens, ExpressionNode Value, int StartWidth, int TargetWidth) : RValueExpressionNode(Tokens, new NubFloatType(TargetWidth));
public record ConvertStringToCStringNode(List<Token> Tokens, ExpressionNode Value) : RValueExpressionNode(Tokens, new NubCStringType()); public record ConvertStringToCStringNode(List<Token> Tokens, ExpressionNode Value) : RValueExpressionNode(Tokens, new NubCStringType());

View File

@@ -10,7 +10,6 @@ public sealed class TypeChecker
private readonly Dictionary<string, Module> _importedModules; private readonly Dictionary<string, Module> _importedModules;
private readonly Stack<Scope> _scopes = []; private readonly Stack<Scope> _scopes = [];
private readonly Stack<NubType> _funcReturnTypes = [];
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new(); private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
@@ -29,7 +28,6 @@ public sealed class TypeChecker
public CompilationUnit Check() public CompilationUnit Check()
{ {
_scopes.Clear(); _scopes.Clear();
_funcReturnTypes.Clear();
_typeCache.Clear(); _typeCache.Clear();
_resolvingTypes.Clear(); _resolvingTypes.Clear();
@@ -117,26 +115,26 @@ public sealed class TypeChecker
BlockNode? body = null; BlockNode? body = null;
if (node.Body != null) if (node.Body != null)
{ {
_funcReturnTypes.Push(prototype.ReturnType); using (BeginScope())
body = CheckBlock(node.Body);
if (!AlwaysReturns(body))
{ {
if (prototype.ReturnType is NubVoidType) CurrentScope.SetReturnType(prototype.ReturnType);
body = CheckBlock(node.Body);
if (!AlwaysReturns(body))
{ {
body.Statements.Add(new ReturnNode(node.Tokens.Skip(node.Tokens.Count - 1).ToList(), null)); if (prototype.ReturnType is NubVoidType)
} {
else body.Statements.Add(new ReturnNode(node.Tokens.Skip(node.Tokens.Count - 1).ToList(), null));
{ }
Diagnostics.Add(Diagnostic else
.Error("Not all code paths return a value") {
.At(node.Body) Diagnostics.Add(Diagnostic
.Build()); .Error("Not all code paths return a value")
.At(node.Body)
.Build());
}
} }
} }
_funcReturnTypes.Pop();
} }
return new FuncNode(node.Tokens, prototype, body); return new FuncNode(node.Tokens, prototype, body);
@@ -151,12 +149,21 @@ public sealed class TypeChecker
} }
var value = CheckExpression(statement.Value, lValue.Type); var value = CheckExpression(statement.Value, lValue.Type);
if (value.Type != lValue.Type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Cannot assign {value.Type} to {lValue.Type}")
.At(statement.Value)
.Build());
}
return new AssignmentNode(statement.Tokens, lValue, value); return new AssignmentNode(statement.Tokens, lValue, value);
} }
private IfNode CheckIf(IfSyntax statement) private IfNode CheckIf(IfSyntax statement)
{ {
var condition = CheckExpression(statement.Condition, new NubBoolType()); var condition = CheckExpression(statement.Condition);
var body = CheckBlock(statement.Body); var body = CheckBlock(statement.Body);
Variant<IfNode, BlockNode>? elseStatement = null; Variant<IfNode, BlockNode>? elseStatement = null;
if (statement.Else.HasValue) if (statement.Else.HasValue)
@@ -173,7 +180,8 @@ public sealed class TypeChecker
if (statement.Value != null) if (statement.Value != null)
{ {
value = CheckExpression(statement.Value, _funcReturnTypes.Peek()); var expectedReturnType = CurrentScope.GetReturnType();
value = CheckExpression(statement.Value, expectedReturnType);
} }
return new ReturnNode(statement.Tokens, value); return new ReturnNode(statement.Tokens, value);
@@ -203,12 +211,26 @@ public sealed class TypeChecker
if (statement.Assignment != null) if (statement.Assignment != null)
{ {
assignmentNode = CheckExpression(statement.Assignment, type); assignmentNode = CheckExpression(statement.Assignment, type);
type ??= assignmentNode.Type;
if (type == null)
{
type = assignmentNode.Type;
}
else if (assignmentNode.Type != type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Cannot assign {assignmentNode.Type} to variable of type {type}")
.At(statement.Assignment)
.Build());
}
} }
if (type == null) if (type == null)
{ {
throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build()); throw new TypeCheckerException(Diagnostic
.Error($"Cannot infer type of variable {statement.Name}")
.At(statement)
.Build());
} }
CurrentScope.DeclareVariable(new Variable(statement.Name, type)); CurrentScope.DeclareVariable(new Variable(statement.Name, type));
@@ -218,7 +240,7 @@ public sealed class TypeChecker
private WhileNode CheckWhile(WhileSyntax statement) private WhileNode CheckWhile(WhileSyntax statement)
{ {
var condition = CheckExpression(statement.Condition, new NubBoolType()); var condition = CheckExpression(statement.Condition);
var body = CheckBlock(statement.Body); var body = CheckBlock(statement.Body);
return new WhileNode(statement.Tokens, condition, body); return new WhileNode(statement.Tokens, condition, body);
} }
@@ -236,64 +258,32 @@ public sealed class TypeChecker
private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null) private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null)
{ {
var result = node switch return node switch
{ {
AddressOfSyntax expression => CheckAddressOf(expression), AddressOfSyntax expression => CheckAddressOf(expression, expectedType),
ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression), ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression, expectedType),
ArrayInitializerSyntax expression => CheckArrayInitializer(expression), ArrayInitializerSyntax expression => CheckArrayInitializer(expression, expectedType),
BinaryExpressionSyntax expression => CheckBinaryExpression(expression, expectedType), BinaryExpressionSyntax expression => CheckBinaryExpression(expression, expectedType),
UnaryExpressionSyntax expression => CheckUnaryExpression(expression, expectedType), UnaryExpressionSyntax expression => CheckUnaryExpression(expression, expectedType),
DereferenceSyntax expression => CheckDereference(expression), DereferenceSyntax expression => CheckDereference(expression, expectedType),
FuncCallSyntax expression => CheckFuncCall(expression), FuncCallSyntax expression => CheckFuncCall(expression, expectedType),
LocalIdentifierSyntax expression => CheckLocalIdentifier(expression), LocalIdentifierSyntax expression => CheckLocalIdentifier(expression, expectedType),
ModuleIdentifierSyntax expression => CheckModuleIdentifier(expression), ModuleIdentifierSyntax expression => CheckModuleIdentifier(expression, expectedType),
BoolLiteralSyntax expression => CheckBoolLiteral(expression), BoolLiteralSyntax expression => CheckBoolLiteral(expression, expectedType),
StringLiteralSyntax expression => CheckStringLiteral(expression, expectedType), StringLiteralSyntax expression => CheckStringLiteral(expression, expectedType),
IntLiteralSyntax expression => CheckIntLiteral(expression, expectedType), IntLiteralSyntax expression => CheckIntLiteral(expression, expectedType),
FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType), FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType),
MemberAccessSyntax expression => CheckMemberAccess(expression), MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType),
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
InterpretBuiltinSyntax expression => CheckExpression(expression.Target) with { Type = ResolveType(expression.Type) }, InterpretBuiltinSyntax expression => CheckExpression(expression.Target, expectedType) with { Type = ResolveType(expression.Type) },
SizeBuiltinSyntax expression => new SizeBuiltinNode(node.Tokens, new NubIntType(false, 64), ResolveType(expression.Type)), SizeBuiltinSyntax expression => new SizeBuiltinNode(node.Tokens, new NubIntType(false, 64), ResolveType(expression.Type)),
FloatToIntBuiltinSyntax expression => CheckFloatToInt(expression), FloatToIntBuiltinSyntax expression => CheckFloatToInt(expression, expectedType),
_ => throw new ArgumentOutOfRangeException(nameof(node)) _ => throw new ArgumentOutOfRangeException(nameof(node))
}; };
if (expectedType == null || result.Type == expectedType)
{
return result;
}
if (result.Type is NubStringType && expectedType is NubCStringType)
{
return new ConvertStringToCStringNode(node.Tokens, result);
}
if (result.Type is NubCStringType && expectedType is NubStringType)
{
return new ConvertCStringToStringNode(node.Tokens, result);
}
if (result.Type is NubIntType sourceIntType && expectedType is NubIntType targetIntType)
{
if (sourceIntType.Signed == targetIntType.Signed && sourceIntType.Width < targetIntType.Width)
{
return new ConvertIntNode(node.Tokens, targetIntType, result, sourceIntType, targetIntType);
}
}
if (result.Type is NubFloatType sourceFloatType && expectedType is NubFloatType targetFloatType)
{
if (sourceFloatType.Width < targetFloatType.Width)
{
return new ConvertFloatNode(node.Tokens, targetFloatType, result, sourceFloatType, targetFloatType);
}
}
throw new TypeCheckerException(Diagnostic.Error($"Cannot convert {result.Type} to {expectedType}").At(node).Build());
} }
private FloatToIntBuiltinNode CheckFloatToInt(FloatToIntBuiltinSyntax expression) // todo(nub31): Infer int type instead of explicit type syntax
private FloatToIntBuiltinNode CheckFloatToInt(FloatToIntBuiltinSyntax expression, NubType? _)
{ {
var value = CheckExpression(expression.Value); var value = CheckExpression(expression.Value);
if (value.Type is not NubFloatType sourceFloatType) if (value.Type is not NubFloatType sourceFloatType)
@@ -316,9 +306,9 @@ public sealed class TypeChecker
return new FloatToIntBuiltinNode(expression.Tokens, targetIntType, value, sourceFloatType, targetIntType); return new FloatToIntBuiltinNode(expression.Tokens, targetIntType, value, sourceFloatType, targetIntType);
} }
private AddressOfNode CheckAddressOf(AddressOfSyntax expression) private AddressOfNode CheckAddressOf(AddressOfSyntax expression, NubType? expectedType)
{ {
var target = CheckExpression(expression.Target); var target = CheckExpression(expression.Target, (expectedType as NubPointerType)?.BaseType);
if (target is not LValueExpressionNode lvalue) if (target is not LValueExpressionNode lvalue)
{ {
throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build()); throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build());
@@ -328,7 +318,7 @@ public sealed class TypeChecker
return new AddressOfNode(expression.Tokens, type, lvalue); return new AddressOfNode(expression.Tokens, type, lvalue);
} }
private ExpressionNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression) private ExpressionNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression, NubType? _)
{ {
var index = CheckExpression(expression.Index); var index = CheckExpression(expression.Index);
if (index.Type is not NubIntType) if (index.Type is not NubIntType)
@@ -349,7 +339,8 @@ public sealed class TypeChecker
}; };
} }
private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression) // todo(nub31): Allow type inference instead of specifying type in syntax. Something like just []
private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression, NubType? _)
{ {
var elementType = ResolveType(expression.ElementType); var elementType = ResolveType(expression.ElementType);
var type = new NubArrayType(elementType); var type = new NubArrayType(elementType);
@@ -387,7 +378,6 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.Equal: case BinaryOperatorSyntax.Equal:
case BinaryOperatorSyntax.NotEqual: case BinaryOperatorSyntax.NotEqual:
{ {
// note(nub31): Expected type should not be passed down since the resulting type is different than operands
var left = CheckExpression(expression.Left); var left = CheckExpression(expression.Left);
if (left.Type is not NubIntType and not NubFloatType and not NubBoolType) if (left.Type is not NubIntType and not NubFloatType and not NubBoolType)
{ {
@@ -406,7 +396,6 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.LessThan: case BinaryOperatorSyntax.LessThan:
case BinaryOperatorSyntax.LessThanOrEqual: case BinaryOperatorSyntax.LessThanOrEqual:
{ {
// note(nub31): Expected type should not be passed down since the resulting type is different than operands
var left = CheckExpression(expression.Left); var left = CheckExpression(expression.Left);
if (left.Type is not NubIntType and not NubFloatType) if (left.Type is not NubIntType and not NubFloatType)
{ {
@@ -423,7 +412,6 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.LogicalAnd: case BinaryOperatorSyntax.LogicalAnd:
case BinaryOperatorSyntax.LogicalOr: case BinaryOperatorSyntax.LogicalOr:
{ {
// note(nub31): Expected type should not be passed down since the resulting type is different than operands
var left = CheckExpression(expression.Left); var left = CheckExpression(expression.Left);
if (left.Type is not NubBoolType) if (left.Type is not NubBoolType)
{ {
@@ -440,16 +428,16 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.Plus: case BinaryOperatorSyntax.Plus:
{ {
var left = CheckExpression(expression.Left, expectedType); var left = CheckExpression(expression.Left, expectedType);
if (left.Type is NubIntType or NubFloatType or NubStringType or NubCStringType) if (left.Type is not NubIntType and not NubFloatType and not NubStringType and not NubCStringType)
{ {
var right = CheckExpression(expression.Right, left.Type); throw new TypeCheckerException(Diagnostic
return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right); .Error("The plus operator must be used with int, float, cstring or string types")
.At(expression.Left)
.Build());
} }
throw new TypeCheckerException(Diagnostic var right = CheckExpression(expression.Right, left.Type);
.Error("The plus operator must be used with int, float, cstring or string types") return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right);
.At(expression.Left)
.Build());
} }
case BinaryOperatorSyntax.Minus: case BinaryOperatorSyntax.Minus:
case BinaryOperatorSyntax.Multiply: case BinaryOperatorSyntax.Multiply:
@@ -532,7 +520,7 @@ public sealed class TypeChecker
} }
} }
private DereferenceNode CheckDereference(DereferenceSyntax expression) private DereferenceNode CheckDereference(DereferenceSyntax expression, NubType? _)
{ {
var target = CheckExpression(expression.Target); var target = CheckExpression(expression.Target);
if (target.Type is not NubPointerType pointerType) if (target.Type is not NubPointerType pointerType)
@@ -543,7 +531,7 @@ public sealed class TypeChecker
return new DereferenceNode(expression.Tokens, pointerType.BaseType, target); return new DereferenceNode(expression.Tokens, pointerType.BaseType, target);
} }
private FuncCallNode CheckFuncCall(FuncCallSyntax expression) private FuncCallNode CheckFuncCall(FuncCallSyntax expression, NubType? _)
{ {
var accessor = CheckExpression(expression.Expression); var accessor = CheckExpression(expression.Expression);
if (accessor.Type is not NubFuncType funcType) if (accessor.Type is not NubFuncType funcType)
@@ -563,13 +551,13 @@ public sealed class TypeChecker
for (var i = 0; i < expression.Parameters.Count; i++) for (var i = 0; i < expression.Parameters.Count; i++)
{ {
var parameter = expression.Parameters[i]; var parameter = expression.Parameters[i];
var expectedType = funcType.Parameters[i]; var expectedParameterType = funcType.Parameters[i];
var parameterExpression = CheckExpression(parameter, expectedType); var parameterExpression = CheckExpression(parameter, expectedParameterType);
if (parameterExpression.Type != expectedType) if (parameterExpression.Type != expectedParameterType)
{ {
throw new TypeCheckerException(Diagnostic throw new TypeCheckerException(Diagnostic
.Error($"Parameter {i + 1} does not match the type {expectedType} for function {funcType}") .Error($"Parameter {i + 1} does not match the type {expectedParameterType} for function {funcType}")
.At(parameter) .At(parameter)
.Build()); .Build());
} }
@@ -580,7 +568,7 @@ public sealed class TypeChecker
return new FuncCallNode(expression.Tokens, funcType.ReturnType, accessor, parameters); return new FuncCallNode(expression.Tokens, funcType.ReturnType, accessor, parameters);
} }
private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression) private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression, NubType? _)
{ {
var scopeIdent = CurrentScope.LookupVariable(expression.Name); var scopeIdent = CurrentScope.LookupVariable(expression.Name);
if (scopeIdent != null) if (scopeIdent != null)
@@ -601,7 +589,7 @@ public sealed class TypeChecker
throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build()); throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build());
} }
private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression) private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression, NubType? _)
{ {
if (!_importedModules.TryGetValue(expression.Module, out var module)) if (!_importedModules.TryGetValue(expression.Module, out var module))
{ {
@@ -630,51 +618,62 @@ public sealed class TypeChecker
private ExpressionNode CheckStringLiteral(StringLiteralSyntax expression, NubType? expectedType) private ExpressionNode CheckStringLiteral(StringLiteralSyntax expression, NubType? expectedType)
{ {
return expectedType switch if (expectedType is NubCStringType)
{ {
NubCStringType => new CStringLiteralNode(expression.Tokens, expectedType, expression.Value), return new CStringLiteralNode(expression.Tokens, expression.Value);
NubStringType => new StringLiteralNode(expression.Tokens, expectedType, expression.Value), }
_ => new StringLiteralNode(expression.Tokens, new NubStringType(), expression.Value)
}; return new StringLiteralNode(expression.Tokens, expression.Value);
} }
private ExpressionNode CheckIntLiteral(IntLiteralSyntax expression, NubType? expectedType) private ExpressionNode CheckIntLiteral(IntLiteralSyntax expression, NubType? expectedType)
{ {
if (expectedType is NubIntType intType) if (expectedType is NubIntType intType)
{ {
return intType.Signed return intType.Width switch
? new IntLiteralNode(expression.Tokens, intType, Convert.ToInt64(expression.Value, expression.Base)) {
: new UIntLiteralNode(expression.Tokens, intType, Convert.ToUInt64(expression.Value, expression.Base)); 8 => intType.Signed ? new I8LiteralNode(expression.Tokens, Convert.ToSByte(expression.Value, expression.Base)) : new U8LiteralNode(expression.Tokens, Convert.ToByte(expression.Value, expression.Base)),
16 => intType.Signed ? new I16LiteralNode(expression.Tokens, Convert.ToInt16(expression.Value, expression.Base)) : new U16LiteralNode(expression.Tokens, Convert.ToUInt16(expression.Value, expression.Base)),
32 => intType.Signed ? new I32LiteralNode(expression.Tokens, Convert.ToInt32(expression.Value, expression.Base)) : new U32LiteralNode(expression.Tokens, Convert.ToUInt32(expression.Value, expression.Base)),
64 => intType.Signed ? new I64LiteralNode(expression.Tokens, Convert.ToInt64(expression.Value, expression.Base)) : new U64LiteralNode(expression.Tokens, Convert.ToUInt64(expression.Value, expression.Base)),
_ => throw new ArgumentOutOfRangeException()
};
} }
if (expectedType is NubFloatType floatType) if (expectedType is NubFloatType floatType)
{ {
return floatType.Width switch return floatType.Width switch
{ {
32 => new Float32LiteralNode(expression.Tokens, floatType, Convert.ToSingle(Convert.ToInt64(expression.Value, expression.Base))), 32 => new Float32LiteralNode(expression.Tokens, Convert.ToSingle(expression.Value)),
64 => new Float64LiteralNode(expression.Tokens, floatType, Convert.ToDouble(Convert.ToInt64(expression.Value, expression.Base))), 64 => new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value)),
_ => throw new ArgumentOutOfRangeException() _ => throw new ArgumentOutOfRangeException()
}; };
} }
var type = new NubIntType(true, 64); return new I64LiteralNode(expression.Tokens, Convert.ToInt64(expression.Value, expression.Base));
return new IntLiteralNode(expression.Tokens, type, Convert.ToInt64(expression.Value, expression.Base));
} }
private ExpressionNode CheckFloatLiteral(FloatLiteralSyntax expression, NubType? expectedType) private ExpressionNode CheckFloatLiteral(FloatLiteralSyntax expression, NubType? expectedType)
{ {
var type = expectedType as NubFloatType ?? new NubFloatType(64); if (expectedType is NubFloatType floatType)
return type.Width == 32 {
? new Float32LiteralNode(expression.Tokens, type, Convert.ToSingle(expression.Value)) return floatType.Width switch
: new Float64LiteralNode(expression.Tokens, type, Convert.ToDouble(expression.Value)); {
32 => new Float32LiteralNode(expression.Tokens, Convert.ToSingle(expression.Value)),
64 => new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value)),
_ => throw new ArgumentOutOfRangeException()
};
}
return new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value));
} }
private BoolLiteralNode CheckBoolLiteral(BoolLiteralSyntax expression) private BoolLiteralNode CheckBoolLiteral(BoolLiteralSyntax expression, NubType? _)
{ {
return new BoolLiteralNode(expression.Tokens, new NubBoolType(), expression.Value); return new BoolLiteralNode(expression.Tokens, new NubBoolType(), expression.Value);
} }
private ExpressionNode CheckMemberAccess(MemberAccessSyntax expression) private ExpressionNode CheckMemberAccess(MemberAccessSyntax expression, NubType? _)
{ {
var target = CheckExpression(expression.Target); var target = CheckExpression(expression.Target);
switch (target.Type) switch (target.Type)
@@ -897,6 +896,7 @@ public record Variable(string Name, NubType Type);
public class Scope(string module, Scope? parent = null) public class Scope(string module, Scope? parent = null)
{ {
private NubType? _returnType;
private readonly List<Variable> _variables = []; private readonly List<Variable> _variables = [];
public string Module { get; } = module; public string Module { get; } = module;
@@ -905,6 +905,16 @@ public class Scope(string module, Scope? parent = null)
_variables.Add(variable); _variables.Add(variable);
} }
public void SetReturnType(NubType returnType)
{
_returnType = returnType;
}
public NubType? GetReturnType()
{
return _returnType ?? parent?.GetReturnType();
}
public Variable? LookupVariable(string name) public Variable? LookupVariable(string name)
{ {
var variable = _variables.FirstOrDefault(x => x.Name == name); var variable = _variables.FirstOrDefault(x => x.Name == name);

View File

@@ -14,12 +14,12 @@ public static class CType
NubFloatType floatType => CreateFloatType(floatType, variableName), NubFloatType floatType => CreateFloatType(floatType, variableName),
NubCStringType => "char*" + (variableName != null ? $" {variableName}" : ""), NubCStringType => "char*" + (variableName != null ? $" {variableName}" : ""),
NubPointerType ptr => CreatePointerType(ptr, variableName), NubPointerType ptr => CreatePointerType(ptr, variableName),
NubSliceType nubSliceType => "slice" + (variableName != null ? $" {variableName}" : ""), NubSliceType => "slice" + (variableName != null ? $" {variableName}" : ""),
NubStringType nubStringType => "string" + (variableName != null ? $" {variableName}" : ""), NubStringType => "string" + (variableName != null ? $" {variableName}" : ""),
NubConstArrayType arr => CreateConstArrayType(arr, variableName), NubConstArrayType arr => CreateConstArrayType(arr, variableName),
NubArrayType arr => CreateArrayType(arr, variableName), NubArrayType arr => CreateArrayType(arr, variableName),
NubFuncType fn => CreateFuncType(fn, variableName), NubFuncType fn => CreateFuncType(fn, variableName),
NubStructType st => $"{st.Name}" + (variableName != null ? $" {variableName}" : ""), NubStructType st => $"{st.Module}_{st.Name}" + (variableName != null ? $" {variableName}" : ""),
_ => throw new NotSupportedException($"C type generation not supported for: {type}") _ => throw new NotSupportedException($"C type generation not supported for: {type}")
}; };
} }

View File

@@ -35,12 +35,10 @@ public class Generator
public string Emit() public string Emit()
{ {
_writer.WriteLine("#include <stdint.h>");
_writer.WriteLine("#include <stdarg.h>");
_writer.WriteLine("#include <stddef.h>");
_writer.WriteLine();
_writer.WriteLine(""" _writer.WriteLine("""
#include <stdint.h>
#include <stddef.h>
typedef struct typedef struct
{ {
size_t length; size_t length;
@@ -274,14 +272,20 @@ public class Generator
FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode), FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode),
FuncCallNode funcCallNode => EmitFuncCall(funcCallNode), FuncCallNode funcCallNode => EmitFuncCall(funcCallNode),
FuncIdentifierNode funcIdentifierNode => FuncName(funcIdentifierNode.Module, funcIdentifierNode.Name, funcIdentifierNode.ExternSymbol), FuncIdentifierNode funcIdentifierNode => FuncName(funcIdentifierNode.Module, funcIdentifierNode.Name, funcIdentifierNode.ExternSymbol),
IntLiteralNode intLiteralNode => EmitIntLiteral(intLiteralNode),
AddressOfNode addressOfNode => EmitAddressOf(addressOfNode), AddressOfNode addressOfNode => EmitAddressOf(addressOfNode),
SizeBuiltinNode sizeBuiltinNode => $"sizeof({CType.Create(sizeBuiltinNode.TargetType)})", SizeBuiltinNode sizeBuiltinNode => $"sizeof({CType.Create(sizeBuiltinNode.TargetType)})",
SliceIndexAccessNode sliceIndexAccessNode => EmitSliceArrayIndexAccess(sliceIndexAccessNode), SliceIndexAccessNode sliceIndexAccessNode => EmitSliceArrayIndexAccess(sliceIndexAccessNode),
StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode), StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode),
StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode), StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode),
StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode), StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode),
UIntLiteralNode uIntLiteralNode => EmitUIntLiteral(uIntLiteralNode), I8LiteralNode i8LiteralNode => EmitI8Literal(i8LiteralNode),
I16LiteralNode i16LiteralNode => EmitI16Literal(i16LiteralNode),
I32LiteralNode i32LiteralNode => EmitI32Literal(i32LiteralNode),
I64LiteralNode i64LiteralNode => EmitI64Literal(i64LiteralNode),
U8LiteralNode u8LiteralNode => EmitU8Literal(u8LiteralNode),
U16LiteralNode u16LiteralNode => EmitU16Literal(u16LiteralNode),
U32LiteralNode u32LiteralNode => EmitU32Literal(u32LiteralNode),
U64LiteralNode u64LiteralNode => EmitU64Literal(u64LiteralNode),
UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(unaryExpressionNode), UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(unaryExpressionNode),
VariableIdentifierNode variableIdentifierNode => variableIdentifierNode.Name, VariableIdentifierNode variableIdentifierNode => variableIdentifierNode.Name,
_ => throw new ArgumentOutOfRangeException(nameof(expressionNode)) _ => throw new ArgumentOutOfRangeException(nameof(expressionNode))
@@ -345,7 +349,7 @@ public class Generator
private string EmitConvertFloat(ConvertFloatNode convertFloatNode) private string EmitConvertFloat(ConvertFloatNode convertFloatNode)
{ {
var value = EmitExpression(convertFloatNode.Value); var value = EmitExpression(convertFloatNode.Value);
var targetCast = convertFloatNode.TargetType.Width switch var targetCast = convertFloatNode.TargetWidth switch
{ {
32 => "f32", 32 => "f32",
64 => "f64", 64 => "f64",
@@ -358,12 +362,12 @@ public class Generator
private string EmitConvertInt(ConvertIntNode convertIntNode) private string EmitConvertInt(ConvertIntNode convertIntNode)
{ {
var value = EmitExpression(convertIntNode.Value); var value = EmitExpression(convertIntNode.Value);
var targetType = convertIntNode.TargetType.Width switch var targetType = convertIntNode.TargetWidth switch
{ {
8 => convertIntNode.TargetType.Signed ? "int8_t" : "uint8_t", 8 => convertIntNode.TargetSignedness ? "int8_t" : "uint8_t",
16 => convertIntNode.TargetType.Signed ? "int16_t" : "uint16_t", 16 => convertIntNode.TargetSignedness ? "int16_t" : "uint16_t",
32 => convertIntNode.TargetType.Signed ? "int32_t" : "uint32_t", 32 => convertIntNode.TargetSignedness ? "int32_t" : "uint32_t",
64 => convertIntNode.TargetType.Signed ? "int64_t" : "uint64_t", 64 => convertIntNode.TargetSignedness ? "int64_t" : "uint64_t",
_ => throw new ArgumentOutOfRangeException() _ => throw new ArgumentOutOfRangeException()
}; };
return $"({targetType}){value}"; return $"({targetType}){value}";
@@ -420,21 +424,10 @@ public class Generator
private string EmitFuncCall(FuncCallNode funcCallNode) private string EmitFuncCall(FuncCallNode funcCallNode)
{ {
var name = EmitExpression(funcCallNode.Expression); var name = EmitExpression(funcCallNode.Expression);
var parameterNames = funcCallNode.Parameters.Select(x => EmitExpression(x)).ToList(); var parameterNames = funcCallNode.Parameters.Select(EmitExpression).ToList();
return $"{name}({string.Join(", ", parameterNames)})"; return $"{name}({string.Join(", ", parameterNames)})";
} }
private string EmitIntLiteral(IntLiteralNode intLiteralNode)
{
var type = (NubIntType)intLiteralNode.Type;
return type.Width switch
{
8 or 16 or 32 => intLiteralNode.Value.ToString(),
64 => intLiteralNode.Value + "LL",
_ => throw new ArgumentOutOfRangeException()
};
}
private string EmitAddressOf(AddressOfNode addressOfNode) private string EmitAddressOf(AddressOfNode addressOfNode)
{ {
var value = EmitExpression(addressOfNode.LValue); var value = EmitExpression(addressOfNode.LValue);
@@ -477,15 +470,44 @@ public class Generator
return $"({CType.Create(structInitializerNode.Type)}){{{initString}}}"; return $"({CType.Create(structInitializerNode.Type)}){{{initString}}}";
} }
private string EmitUIntLiteral(UIntLiteralNode uIntLiteralNode) private string EmitI8Literal(I8LiteralNode i8LiteralNode)
{ {
var type = (NubIntType)uIntLiteralNode.Type; return i8LiteralNode.Value.ToString();
return type.Width switch }
{
8 or 16 or 32 => uIntLiteralNode.Value + "U", private string EmitI16Literal(I16LiteralNode i16LiteralNode)
64 => uIntLiteralNode.Value + "ULL", {
_ => throw new ArgumentOutOfRangeException() return i16LiteralNode.Value.ToString();
}; }
private string EmitI32Literal(I32LiteralNode i32LiteralNode)
{
return i32LiteralNode.Value.ToString();
}
private string EmitI64Literal(I64LiteralNode i64LiteralNode)
{
return i64LiteralNode.Value + "LL";
}
private string EmitU8Literal(U8LiteralNode u8LiteralNode)
{
return u8LiteralNode.Value.ToString();
}
private string EmitU16Literal(U16LiteralNode u16LiteralNode)
{
return u16LiteralNode.Value.ToString();
}
private string EmitU32Literal(U32LiteralNode u32LiteralNode)
{
return u32LiteralNode.Value.ToString();
}
private string EmitU64Literal(U64LiteralNode u64LiteralNode)
{
return u64LiteralNode.Value + "ULL";
} }
private string EmitUnaryExpression(UnaryExpressionNode unaryExpressionNode) private string EmitUnaryExpression(UnaryExpressionNode unaryExpressionNode)

View File

@@ -5,5 +5,6 @@ extern "puts" func puts(text: cstring)
extern "main" func main(argc: i64, argv: [?]cstring): i64 extern "main" func main(argc: i64, argv: [?]cstring): i64
{ {
let x = [23]i32 let x = [23]i32
puts("test")
return x[0] return x[0]
} }

View File

@@ -1,5 +1,5 @@
.build/out: main.nub .build/out: main.nub
nubc main.nub clang $$(nubc main.nub) -o .build/out
clean: clean:
@rm -r .build 2>/dev/null || true @rm -r .build 2>/dev/null || true

View File

@@ -1,5 +1,5 @@
.build/out: main.nub .build/out: main.nub
nubc main.nub clang $$(nubc main.nub) -o .build/out
clean: clean:
@rm -r .build 2>/dev/null || true @rm -r .build 2>/dev/null || true