This commit is contained in:
nub31
2025-09-20 18:17:40 +02:00
parent 4533f69683
commit 6c56404f1c
10 changed files with 292 additions and 85 deletions

View File

@@ -693,6 +693,7 @@ public class QBEGenerator
StructFuncCallNode expr => EmitStructFuncCall(expr), StructFuncCallNode expr => EmitStructFuncCall(expr),
StructInitializerNode expr => EmitStructInitializer(expr), StructInitializerNode expr => EmitStructInitializer(expr),
UnaryExpressionNode expr => EmitUnaryExpression(expr), UnaryExpressionNode expr => EmitUnaryExpression(expr),
SizeCompilerMacroNode expr => $"{SizeOf(expr.TargetType)}",
_ => throw new ArgumentOutOfRangeException(nameof(rValue)) _ => throw new ArgumentOutOfRangeException(nameof(rValue))
}; };
} }

View File

@@ -42,6 +42,11 @@ public class ModuleRepository
module.RegisterStruct(structDef.Exported, structDef.Name, fields, functions); module.RegisterStruct(structDef.Exported, structDef.Name, fields, functions);
break; break;
} }
case StructTemplateSyntax structDef:
{
// todo(nub31): Include templates in modules
break;
}
default: default:
{ {
throw new ArgumentOutOfRangeException(nameof(definition)); throw new ArgumentOutOfRangeException(nameof(definition));

View File

@@ -163,10 +163,20 @@ public sealed class Parser
return new FuncSyntax(GetTokens(startIndex), name.Value, exported, externSymbol, signature, body); return new FuncSyntax(GetTokens(startIndex), name.Value, exported, externSymbol, signature, body);
} }
private StructSyntax ParseStruct(int startIndex, bool exported) private DefinitionSyntax ParseStruct(int startIndex, bool exported)
{ {
var name = ExpectIdentifier(); var name = ExpectIdentifier();
var templateArguments = new List<string>();
if (TryExpectSymbol(Symbol.LessThan))
{
while (!TryExpectSymbol(Symbol.GreaterThan))
{
templateArguments.Add(ExpectIdentifier().Value);
TryExpectSymbol(Symbol.Comma);
}
}
ExpectSymbol(Symbol.OpenBrace); ExpectSymbol(Symbol.OpenBrace);
List<StructFieldSyntax> fields = []; List<StructFieldSyntax> fields = [];
@@ -207,6 +217,11 @@ public sealed class Parser
} }
} }
if (templateArguments.Count > 0)
{
return new StructTemplateSyntax(GetTokens(startIndex), templateArguments, name.Value, exported, fields, funcs);
}
return new StructSyntax(GetTokens(startIndex), name.Value, exported, fields, funcs); return new StructSyntax(GetTokens(startIndex), name.Value, exported, fields, funcs);
} }
@@ -462,6 +477,7 @@ public sealed class Parser
Symbol.OpenBracket => ParseArrayInitializer(startIndex), Symbol.OpenBracket => ParseArrayInitializer(startIndex),
Symbol.OpenBrace => new StructInitializerSyntax(GetTokens(startIndex), Optional<TypeSyntax>.Empty(), ParseStructInitializerBody()), Symbol.OpenBrace => new StructInitializerSyntax(GetTokens(startIndex), Optional<TypeSyntax>.Empty(), ParseStructInitializerBody()),
Symbol.Struct => ParseStructInitializer(startIndex), Symbol.Struct => ParseStructInitializer(startIndex),
Symbol.At => ParseCompilerMacro(startIndex),
_ => throw new ParseException(Diagnostic _ => throw new ParseException(Diagnostic
.Error($"Unexpected symbol '{symbolToken.Symbol}' in expression") .Error($"Unexpected symbol '{symbolToken.Symbol}' in expression")
.WithHelp("Expected '(', '-', '!', '[' or '{'") .WithHelp("Expected '(', '-', '!', '[' or '{'")
@@ -478,6 +494,34 @@ public sealed class Parser
return ParsePostfixOperators(expr); return ParsePostfixOperators(expr);
} }
private ExpressionSyntax ParseCompilerMacro(int startIndex)
{
var name = ExpectIdentifier();
ExpectSymbol(Symbol.OpenParen);
switch (name.Value)
{
case "size":
{
var type = ParseType();
ExpectSymbol(Symbol.CloseParen);
return new SizeCompilerMacroSyntax(GetTokens(startIndex), type);
}
case "interpret":
{
var type = ParseType();
ExpectSymbol(Symbol.Comma);
var expression = ParseExpression();
ExpectSymbol(Symbol.CloseParen);
return new InterpretCompilerMacroSyntax(GetTokens(startIndex), type, expression);
}
default:
{
throw new ParseException(Diagnostic.Error("Unknown compiler macro").At(name).Build());
}
}
}
private ExpressionSyntax ParseIdentifier(int startIndex, IdentifierToken identifier) private ExpressionSyntax ParseIdentifier(int startIndex, IdentifierToken identifier)
{ {
if (TryExpectSymbol(Symbol.DoubleColon)) if (TryExpectSymbol(Symbol.DoubleColon))
@@ -699,13 +743,31 @@ public sealed class Parser
return new BoolTypeSyntax(GetTokens(startIndex)); return new BoolTypeSyntax(GetTokens(startIndex));
default: default:
{ {
var module = _moduleName;
if (TryExpectSymbol(Symbol.DoubleColon)) if (TryExpectSymbol(Symbol.DoubleColon))
{ {
var customTypeName = ExpectIdentifier().Value; var customTypeName = ExpectIdentifier();
return new CustomTypeSyntax(GetTokens(startIndex), name.Value, customTypeName); module = name.Value;
name = customTypeName;
} }
return new CustomTypeSyntax(GetTokens(startIndex), _moduleName, name.Value); var templateParameters = new List<TypeSyntax>();
if (TryExpectSymbol(Symbol.LessThan))
{
while (!TryExpectSymbol(Symbol.GreaterThan))
{
templateParameters.Add(ParseType());
TryExpectSymbol(Symbol.Comma);
}
}
if (templateParameters.Count > 0)
{
return new TemplateTypeSyntax(GetTokens(startIndex), templateParameters, module, name.Value);
}
return new CustomTypeSyntax(GetTokens(startIndex), module, name.Value);
} }
} }
} }

View File

@@ -15,3 +15,5 @@ public record StructFieldSyntax(IEnumerable<Token> Tokens, string Name, TypeSynt
public record StructFuncSyntax(IEnumerable<Token> Tokens, string Name, string? Hook, FuncSignatureSyntax Signature, BlockSyntax Body) : SyntaxNode(Tokens); public record StructFuncSyntax(IEnumerable<Token> Tokens, string Name, string? Hook, FuncSignatureSyntax Signature, BlockSyntax Body) : SyntaxNode(Tokens);
public record StructSyntax(IEnumerable<Token> Tokens, string Name, bool Exported, List<StructFieldSyntax> Fields, List<StructFuncSyntax> Functions) : DefinitionSyntax(Tokens, Name, Exported); public record StructSyntax(IEnumerable<Token> Tokens, string Name, bool Exported, List<StructFieldSyntax> Fields, List<StructFuncSyntax> Functions) : DefinitionSyntax(Tokens, Name, Exported);
public record StructTemplateSyntax(IEnumerable<Token> Tokens, List<string> TemplateArguments, string Name, bool Exported, List<StructFieldSyntax> Fields, List<StructFuncSyntax> Functions) : DefinitionSyntax(Tokens, Name, Exported);

View File

@@ -57,3 +57,7 @@ public record StructFieldAccessSyntax(IEnumerable<Token> Tokens, ExpressionSynta
public record StructInitializerSyntax(IEnumerable<Token> Tokens, Optional<TypeSyntax> StructType, Dictionary<string, ExpressionSyntax> Initializers) : ExpressionSyntax(Tokens); public record StructInitializerSyntax(IEnumerable<Token> Tokens, Optional<TypeSyntax> StructType, Dictionary<string, ExpressionSyntax> Initializers) : ExpressionSyntax(Tokens);
public record DereferenceSyntax(IEnumerable<Token> Tokens, ExpressionSyntax Target) : ExpressionSyntax(Tokens); public record DereferenceSyntax(IEnumerable<Token> Tokens, ExpressionSyntax Target) : ExpressionSyntax(Tokens);
public record SizeCompilerMacroSyntax(IEnumerable<Token> Tokens, TypeSyntax Type) : ExpressionSyntax(Tokens);
public record InterpretCompilerMacroSyntax(IEnumerable<Token> Tokens, TypeSyntax Type, ExpressionSyntax Target) : ExpressionSyntax(Tokens);

View File

@@ -23,3 +23,5 @@ public record CStringTypeSyntax(IEnumerable<Token> Tokens) : TypeSyntax(Tokens);
public record ArrayTypeSyntax(IEnumerable<Token> Tokens, TypeSyntax BaseType) : TypeSyntax(Tokens); public record ArrayTypeSyntax(IEnumerable<Token> Tokens, TypeSyntax BaseType) : TypeSyntax(Tokens);
public record CustomTypeSyntax(IEnumerable<Token> Tokens, string Module, string Name) : TypeSyntax(Tokens); public record CustomTypeSyntax(IEnumerable<Token> Tokens, string Module, string Name) : TypeSyntax(Tokens);
public record TemplateTypeSyntax(IEnumerable<Token> Tokens, List<TypeSyntax> TemplateParameters, string Module, string Name) : TypeSyntax(Tokens);

View File

@@ -77,3 +77,5 @@ public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : LValue
public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type); public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type);
public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType) : RValueExpressionNode(Type); public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType) : RValueExpressionNode(Type);
public record SizeCompilerMacroNode(TypeNode Type, TypeNode TargetType) : RValueExpressionNode(Type);

View File

@@ -1,4 +1,7 @@
namespace NubLang.TypeChecking.Node; using System.Security.Cryptography;
using System.Text;
namespace NubLang.TypeChecking.Node;
public abstract class TypeNode : IEquatable<TypeNode> public abstract class TypeNode : IEquatable<TypeNode>
{ {
@@ -157,3 +160,34 @@ public class StringTypeNode : TypeNode
public override bool Equals(TypeNode? other) => other is StringTypeNode; public override bool Equals(TypeNode? other) => other is StringTypeNode;
public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode)); public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode));
} }
public static class NameMangler
{
public static string Mangle(params IEnumerable<TypeNode> types)
{
var readable = string.Join("_", types.Select(EncodeType));
return ComputeShortHash(readable);
}
private static string EncodeType(TypeNode node) => node switch
{
VoidTypeNode => "V",
BoolTypeNode => "B",
IntTypeNode i => (i.Signed ? "I" : "U") + i.Width,
FloatTypeNode f => "F" + f.Width,
CStringTypeNode => "CS",
StringTypeNode => "S",
PointerTypeNode p => "P" + EncodeType(p.BaseType),
ArrayTypeNode a => "A" + EncodeType(a.ElementType),
FuncTypeNode fn => "FN(" + string.Join(",", fn.Parameters.Select(EncodeType)) + ")" + EncodeType(fn.ReturnType),
StructTypeNode st => "ST(" + st.Module + "." + st.Name + ")",
_ => throw new NotSupportedException($"Cannot encode type: {node}")
};
private static string ComputeShortHash(string input)
{
var bytes = Encoding.UTF8.GetBytes(input);
var hash = SHA256.HashData(bytes);
return Convert.ToHexString(hash[..8]).ToLower();
}
}

View File

@@ -1,5 +1,6 @@
using System.Diagnostics; using System.Diagnostics;
using System.Globalization; using System.Globalization;
using System.Security.Cryptography;
using NubLang.Diagnostics; using NubLang.Diagnostics;
using NubLang.Modules; using NubLang.Modules;
using NubLang.Parsing.Syntax; using NubLang.Parsing.Syntax;
@@ -15,6 +16,8 @@ public sealed class TypeChecker
private readonly Stack<Scope> _scopes = []; private readonly Stack<Scope> _scopes = [];
private readonly Stack<TypeNode> _funcReturnTypes = []; private readonly Stack<TypeNode> _funcReturnTypes = [];
private readonly Dictionary<(string Module, string Name), TypeNode> _typeCache = new();
private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
private Scope Scope => _scopes.Peek(); private Scope Scope => _scopes.Peek();
@@ -38,12 +41,26 @@ public sealed class TypeChecker
Diagnostics.Clear(); Diagnostics.Clear();
Definitions.Clear(); Definitions.Clear();
ReferencedStructTypes.Clear(); ReferencedStructTypes.Clear();
_typeCache.Clear();
_resolvingTypes.Clear();
foreach (var definition in _syntaxTree.Definitions) foreach (var definition in _syntaxTree.Definitions)
{ {
try try
{ {
Definitions.Add(CheckDefinition(definition)); switch (definition)
{
case FuncSyntax funcSyntax:
Definitions.Add(CheckFuncDefinition(funcSyntax));
break;
case StructSyntax structSyntax:
Definitions.Add(CheckStructDefinition(structSyntax));
break;
case StructTemplateSyntax:
break;
default:
throw new ArgumentOutOfRangeException(nameof(definition));
}
} }
catch (TypeCheckerException e) catch (TypeCheckerException e)
{ {
@@ -52,16 +69,6 @@ public sealed class TypeChecker
} }
} }
private DefinitionNode CheckDefinition(DefinitionSyntax node)
{
return node switch
{
FuncSyntax definition => CheckFuncDefinition(definition),
StructSyntax definition => CheckStructDefinition(definition),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private StructNode CheckStructDefinition(StructSyntax node) private StructNode CheckStructDefinition(StructSyntax node)
{ {
var fieldTypes = node.Fields var fieldTypes = node.Fields
@@ -77,9 +84,29 @@ public sealed class TypeChecker
} }
var type = new StructTypeNode(_syntaxTree.Metadata.ModuleName, node.Name, fieldTypes, functionTypes); var type = new StructTypeNode(_syntaxTree.Metadata.ModuleName, node.Name, fieldTypes, functionTypes);
var fields = node.Fields.Select(CheckStructField).ToList();
var functions = node.Functions.Select(x => CheckStructFunc(type, x)).ToList();
var fields = new List<StructFieldNode>(); return new StructNode(type, _syntaxTree.Metadata.ModuleName, node.Name, fields, functions);
foreach (var field in node.Fields) }
private StructFuncNode CheckStructFunc(StructTypeNode type, StructFuncSyntax function, Scope? scope = null)
{
scope ??= new Scope();
scope.DeclareVariable(new Variable("this", type, VariableKind.FunctionParameter));
foreach (var parameter in function.Signature.Parameters)
{
scope.DeclareVariable(new Variable(parameter.Name, ResolveType(parameter.Type), VariableKind.FunctionParameter));
}
_funcReturnTypes.Push(ResolveType(function.Signature.ReturnType));
var body = CheckBlock(function.Body, scope);
_funcReturnTypes.Pop();
return new StructFuncNode(function.Name, CheckFuncSignature(function.Signature), body);
}
private StructFieldNode CheckStructField(StructFieldSyntax field)
{ {
var value = Optional.Empty<ExpressionNode>(); var value = Optional.Empty<ExpressionNode>();
if (field.Value.HasValue) if (field.Value.HasValue)
@@ -87,27 +114,7 @@ public sealed class TypeChecker
value = CheckExpression(field.Value.Value, ResolveType(field.Type)); value = CheckExpression(field.Value.Value, ResolveType(field.Type));
} }
fields.Add(new StructFieldNode(field.Name, ResolveType(field.Type), value)); return new StructFieldNode(field.Name, ResolveType(field.Type), value);
}
var functions = new List<StructFuncNode>();
foreach (var function in node.Functions)
{
var scope = new Scope();
scope.Declare(new Identifier("this", type, IdentifierKind.FunctionParameter));
foreach (var parameter in function.Signature.Parameters)
{
scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter));
}
_funcReturnTypes.Push(ResolveType(function.Signature.ReturnType));
var body = CheckBlock(function.Body, scope);
_funcReturnTypes.Pop();
functions.Add(new StructFuncNode(function.Name, CheckFuncSignature(function.Signature), body));
}
return new StructNode(type, _syntaxTree.Metadata.ModuleName, node.Name, fields, functions);
} }
private FuncNode CheckFuncDefinition(FuncSyntax node) private FuncNode CheckFuncDefinition(FuncSyntax node)
@@ -115,7 +122,7 @@ public sealed class TypeChecker
var scope = new Scope(); var scope = new Scope();
foreach (var parameter in node.Signature.Parameters) foreach (var parameter in node.Signature.Parameters)
{ {
scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter)); scope.DeclareVariable(new Variable(parameter.Name, ResolveType(parameter.Type), VariableKind.FunctionParameter));
} }
var signature = CheckFuncSignature(node.Signature); var signature = CheckFuncSignature(node.Signature);
@@ -236,7 +243,7 @@ public sealed class TypeChecker
throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build()); throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build());
} }
Scope.Declare(new Identifier(statement.Name, type, IdentifierKind.Variable)); Scope.DeclareVariable(new Variable(statement.Name, type, VariableKind.Variable));
return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type); return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type);
} }
@@ -262,10 +269,10 @@ public sealed class TypeChecker
case ArrayTypeNode arrayType: case ArrayTypeNode arrayType:
{ {
var scope = Scope.SubScope(); var scope = Scope.SubScope();
scope.Declare(new Identifier(statement.ElementIdent, arrayType.ElementType, IdentifierKind.FunctionParameter)); scope.DeclareVariable(new Variable(statement.ElementIdent, arrayType.ElementType, VariableKind.FunctionParameter));
if (statement.IndexIdent != null) if (statement.IndexIdent != null)
{ {
scope.Declare(new Identifier(statement.ElementIdent, new IntTypeNode(true, 64), IdentifierKind.FunctionParameter)); scope.DeclareVariable(new Variable(statement.ElementIdent, new IntTypeNode(true, 64), VariableKind.FunctionParameter));
} }
var body = CheckBlock(statement.Body, scope); var body = CheckBlock(statement.Body, scope);
@@ -273,7 +280,7 @@ public sealed class TypeChecker
} }
default: default:
{ {
throw new TypeCheckerException(Diagnostic.Error($"Type {target.Type} is not an iterable target").Build()); throw new TypeCheckerException(Diagnostic.Error($"Type {target.Type} is not an iterable target").At(statement.Target).Build());
} }
} }
} }
@@ -306,6 +313,8 @@ public sealed class TypeChecker
LiteralSyntax expression => CheckLiteral(expression, expectedType), LiteralSyntax expression => CheckLiteral(expression, expectedType),
StructFieldAccessSyntax expression => CheckStructFieldAccess(expression), StructFieldAccessSyntax expression => CheckStructFieldAccess(expression),
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
InterpretCompilerMacroSyntax expression => CheckExpression(expression.Target) with { Type = ResolveType(expression.Type) },
SizeCompilerMacroSyntax expression => new SizeCompilerMacroNode(new IntTypeNode(false, 64), ResolveType(expression.Type)),
_ => throw new ArgumentOutOfRangeException(nameof(node)) _ => throw new ArgumentOutOfRangeException(nameof(node))
}; };
@@ -338,7 +347,7 @@ public sealed class TypeChecker
var target = CheckExpression(expression.Target); var target = CheckExpression(expression.Target);
if (target is not LValueExpressionNode lvalue) if (target is not LValueExpressionNode lvalue)
{ {
throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").Build()); throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build());
} }
var type = new PointerTypeNode(target.Type); var type = new PointerTypeNode(target.Type);
@@ -469,7 +478,7 @@ public sealed class TypeChecker
var operand = CheckExpression(expression.Operand); var operand = CheckExpression(expression.Operand);
if (operand.Type is not IntTypeNode { Signed: false } or FloatTypeNode) if (operand.Type is not IntTypeNode { Signed: false } or FloatTypeNode)
{ {
throw new TypeCheckerException(Diagnostic.Error("Negation operator must be used with signed integer or float types").Build()); throw new TypeCheckerException(Diagnostic.Error("Negation operator must be used with signed integer or float types").At(expression).Build());
} }
return new UnaryExpressionNode(operand.Type, UnaryOperator.Negate, operand); return new UnaryExpressionNode(operand.Type, UnaryOperator.Negate, operand);
@@ -479,7 +488,7 @@ public sealed class TypeChecker
var operand = CheckExpression(expression.Operand); var operand = CheckExpression(expression.Operand);
if (operand.Type is not BoolTypeNode) if (operand.Type is not BoolTypeNode)
{ {
throw new TypeCheckerException(Diagnostic.Error("Invert operator must be used with booleans").Build()); throw new TypeCheckerException(Diagnostic.Error("Invert operator must be used with booleans").At(expression).Build());
} }
return new UnaryExpressionNode(operand.Type, UnaryOperator.Invert, operand); return new UnaryExpressionNode(operand.Type, UnaryOperator.Invert, operand);
@@ -580,22 +589,25 @@ public sealed class TypeChecker
return new StructFuncCallNode(function.Type.ReturnType, expression.Name, structType, target, parameters); return new StructFuncCallNode(function.Type.ReturnType, expression.Name, structType, target, parameters);
} }
throw new TypeCheckerException(Diagnostic.Error($"No function {expression.Name} exists on type {target.Type}").Build()); throw new TypeCheckerException(Diagnostic
.Error($"No function {expression.Name} exists on type {target.Type}")
.At(expression)
.Build());
} }
private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression) private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression)
{ {
// First, look in the current scope for a matching identifier // First, look in the current scope for a matching identifier
var scopeIdent = Scope.Lookup(expression.Name); var scopeIdent = Scope.LookupVariable(expression.Name);
if (scopeIdent != null) if (scopeIdent != null)
{ {
switch (scopeIdent.Kind) switch (scopeIdent.Kind)
{ {
case IdentifierKind.Variable: case VariableKind.Variable:
{ {
return new VariableIdentifierNode(scopeIdent.Type, expression.Name); return new VariableIdentifierNode(scopeIdent.Type, expression.Name);
} }
case IdentifierKind.FunctionParameter: case VariableKind.FunctionParameter:
{ {
return new FuncParameterIdentifierNode(scopeIdent.Type, expression.Name); return new FuncParameterIdentifierNode(scopeIdent.Type, expression.Name);
} }
@@ -624,7 +636,11 @@ public sealed class TypeChecker
{ {
if (!_visibleModules.TryGetValue(expression.Module, out var module)) if (!_visibleModules.TryGetValue(expression.Module, out var module))
{ {
throw new TypeCheckerException(Diagnostic.Error($"Module {expression.Module} not found").WithHelp($"import \"{expression.Module}\"").At(expression).Build()); throw new TypeCheckerException(Diagnostic
.Error($"Module {expression.Module} not found")
.WithHelp($"import \"{expression.Module}\"")
.At(expression)
.Build());
} }
var includePrivate = expression.Module == _syntaxTree.Metadata.ModuleName; var includePrivate = expression.Module == _syntaxTree.Metadata.ModuleName;
@@ -638,7 +654,10 @@ public sealed class TypeChecker
return new FuncIdentifierNode(type, expression.Module, expression.Name, function.ExternSymbol); return new FuncIdentifierNode(type, expression.Module, expression.Name, function.ExternSymbol);
} }
throw new TypeCheckerException(Diagnostic.Error($"No exported symbol {expression.Name} not found in module {expression.Module}").At(expression).Build()); throw new TypeCheckerException(Diagnostic
.Error($"No exported symbol {expression.Name} not found in module {expression.Module}")
.At(expression)
.Build());
} }
private ExpressionNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType) private ExpressionNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType)
@@ -825,16 +844,74 @@ public sealed class TypeChecker
ArrayTypeSyntax arr => new ArrayTypeNode(ResolveType(arr.BaseType)), ArrayTypeSyntax arr => new ArrayTypeNode(ResolveType(arr.BaseType)),
PointerTypeSyntax ptr => new PointerTypeNode(ResolveType(ptr.BaseType)), PointerTypeSyntax ptr => new PointerTypeNode(ResolveType(ptr.BaseType)),
StringTypeSyntax => new StringTypeNode(), StringTypeSyntax => new StringTypeNode(),
TemplateTypeSyntax template => ResolveTemplateType(template),
VoidTypeSyntax => new VoidTypeNode(), VoidTypeSyntax => new VoidTypeNode(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}") _ => throw new NotSupportedException($"Unknown type syntax: {type}")
}; };
} }
private readonly Dictionary<(string Module, string Name), TypeNode> _typeCache = new(); private StructTypeNode ResolveTemplateType(TemplateTypeSyntax template)
private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; {
// todo(nub31): Add module support for template types
var definition = _syntaxTree.Definitions
.OfType<StructTemplateSyntax>()
.FirstOrDefault(x => x.Name == template.Name);
if (definition == null)
{
throw new TypeCheckerException(Diagnostic.Error($"Template {template.Name} does not exist").At(template).Build());
}
if (definition.TemplateArguments.Count != template.TemplateParameters.Count)
{
throw new TypeCheckerException(Diagnostic
.Error($"Template {template.Name} has {definition.TemplateArguments.Count} arguments, but usage only has {template.TemplateParameters.Count} parameters")
.At(template)
.Build());
}
var scope = new Scope();
for (var i = 0; i < definition.TemplateArguments.Count; i++)
{
scope.DeclareGenericType(definition.TemplateArguments[i], ResolveType(template.TemplateParameters[i]));
}
_scopes.Push(scope);
var fields = definition.Fields
.Select(x => new StructTypeField(x.Name, ResolveType(x.Type), x.Value.HasValue))
.ToList();
var functions = definition.Functions
.Select(x => new StructTypeFunc(x.Name, x.Hook, new FuncTypeNode(x.Signature.Parameters.Select(y => ResolveType(y.Type)).ToList(), ResolveType(x.Signature.ReturnType))))
.ToList();
var name = $"{template.Name}.{NameMangler.Mangle(template.TemplateParameters.Select(ResolveType))}";
var type = new StructTypeNode(template.Module, name, fields, functions);
var checkedFields = definition.Fields.Select(CheckStructField).ToList();
var checkedFunctions = definition.Functions.Select(x => CheckStructFunc(type, x, scope)).ToList();
Definitions.Add(new StructNode(type, template.Module, name, checkedFields, checkedFunctions));
_scopes.Pop();
return type;
}
private TypeNode ResolveCustomType(CustomTypeSyntax customType) private TypeNode ResolveCustomType(CustomTypeSyntax customType)
{ {
if (_syntaxTree.Metadata.ModuleName == customType.Module && _scopes.TryPeek(out var scope))
{
var generic = scope.LookupGenericType(customType.Name);
if (generic != null)
{
return generic;
}
}
var key = (customType.Module, customType.Name); var key = (customType.Module, customType.Name);
if (_typeCache.TryGetValue(key, out var cachedType)) if (_typeCache.TryGetValue(key, out var cachedType))
@@ -853,7 +930,11 @@ public sealed class TypeChecker
{ {
if (!_visibleModules.TryGetValue(customType.Module, out var module)) if (!_visibleModules.TryGetValue(customType.Module, out var module))
{ {
throw new TypeCheckerException(Diagnostic.Error($"Module {customType.Module} not found").WithHelp($"import \"{customType.Module}\"").At(customType).Build()); throw new TypeCheckerException(Diagnostic
.Error($"Module {customType.Module} not found")
.WithHelp($"import \"{customType.Module}\"")
.At(customType)
.Build());
} }
var includePrivate = customType.Module == _syntaxTree.Metadata.ModuleName; var includePrivate = customType.Module == _syntaxTree.Metadata.ModuleName;
@@ -878,7 +959,10 @@ public sealed class TypeChecker
return result; return result;
} }
throw new TypeCheckerException(Diagnostic.Error($"Type {customType.Name} not found in module {customType.Module}").At(customType).Build()); throw new TypeCheckerException(Diagnostic
.Error($"Type {customType.Name} not found in module {customType.Module}")
.At(customType)
.Build());
} }
finally finally
{ {
@@ -887,19 +971,20 @@ public sealed class TypeChecker
} }
} }
public enum IdentifierKind public enum VariableKind
{ {
Variable, Variable,
FunctionParameter FunctionParameter
} }
public record Identifier(string Name, TypeNode Type, IdentifierKind Kind); public record Variable(string Name, TypeNode Type, VariableKind Kind);
public class Scope(Scope? parent = null) public class Scope(Scope? parent = null)
{ {
private readonly List<Identifier> _variables = []; private readonly List<Variable> _variables = [];
private readonly Dictionary<string, TypeNode> _typeArguments = [];
public Identifier? Lookup(string name) public Variable? LookupVariable(string name)
{ {
var variable = _variables.FirstOrDefault(x => x.Name == name); var variable = _variables.FirstOrDefault(x => x.Name == name);
if (variable != null) if (variable != null)
@@ -907,12 +992,22 @@ public class Scope(Scope? parent = null)
return variable; return variable;
} }
return parent?.Lookup(name); return parent?.LookupVariable(name);
} }
public void Declare(Identifier identifier) public void DeclareVariable(Variable variable)
{ {
_variables.Add(identifier); _variables.Add(variable);
}
public void DeclareGenericType(string typeArgument, TypeNode type)
{
_typeArguments[typeArgument] = type;
}
public TypeNode? LookupGenericType(string typeArgument)
{
return _typeArguments.GetValueOrDefault(typeArgument);
} }
public Scope SubScope() public Scope SubScope()

View File

@@ -1,8 +1,8 @@
module "main" module "main"
extern "puts" func puts(text: cstring) extern "puts" func puts(text: cstring)
extern "malloc" func malloc(size: u64): ^u64 extern "malloc" func malloc(size: u64): ^void
extern "free" func free(address: ^u64) extern "free" func free(address: ^void)
struct Human struct Human
{ {
@@ -11,29 +11,29 @@ struct Human
extern "main" func main(args: []cstring): i64 extern "main" func main(args: []cstring): i64
{ {
let x: ref = {} let x: ref<Human> = {}
test(x) test(x)
return 0 return 0
} }
func test(x: ref) func test(x: ref<Human>)
{ {
} }
struct ref struct ref<T>
{ {
value: ^u64 value: ^T
count: ^u64 count: ^u64
@oncreate @oncreate
func on_create() func on_create()
{ {
puts("on_create") puts("on_create")
this.value = malloc(8) this.value = @interpret(^T, malloc(@size(T)))
this.count = malloc(8) this.count = @interpret(^u64, malloc(@size(u64)))
this.count^ = 1 this.count^ = 1
} }
@@ -52,8 +52,8 @@ struct ref
if this.count^ <= 0 if this.count^ <= 0
{ {
puts("free") puts("free")
free(this.value) free(@interpret(^void, this.value))
free(this.count) free(@interpret(^void, this.count))
} }
} }
} }