This repository has been archived on 2025-10-23. You can view files and clone it, but cannot push or open issues or pull requests.
Files
nub-lang-archive/compiler/NubLang/Generation/Generator.cs
nub31 4a01fbc306 ...
2025-10-21 20:22:18 +02:00

547 lines
21 KiB
C#

using LLVMSharp.Interop;
using NubLang.Ast;
namespace NubLang.Generation;
/// <summary>
/// Lowers a <see cref="CompilationUnit"/> into an LLVM module using LLVMSharp.
/// </summary>
/// <remarks>
/// Instances are single-use: <see cref="Generate"/> builds the module once and returns that same
/// module on every subsequent call.
/// </remarks>
public class Generator
{
    private readonly CompilationUnit _compilationUnit;
    private bool _done;

    // Cache of LLVM types for Nub struct/slice/string types so each is created once per module.
    private readonly Dictionary<string, LLVMTypeRef> _namedTypes = new();

    // Storage (stack slots or adopted struct temporaries) of the parameters and locals of the
    // function currently being emitted; cleared at the start of every function body.
    private readonly Dictionary<string, LLVMValueRef> _namedValues = new();

    private readonly LLVMModuleRef _module;
    private LLVMBuilderRef _builder;
    private LLVMValueRef _currentFunction;

    /// <summary>Creates a generator for <paramref name="compilationUnit"/>.</summary>
    /// <param name="compilationUnit">The functions to lower.</param>
    /// <param name="fileName">Name given to the created LLVM module.</param>
    public Generator(CompilationUnit compilationUnit, string fileName)
    {
        _compilationUnit = compilationUnit;
        _module = LLVMModuleRef.CreateWithName(fileName);
    }

    /// <summary>Maps a Nub type to the LLVM type used to represent it.</summary>
    /// <exception cref="ArgumentOutOfRangeException">For unsupported types or float widths.</exception>
    private LLVMTypeRef MapType(NubType nubType)
    {
        return nubType switch
        {
            NubArrayType nubArrayType => LLVMTypeRef.CreatePointer(MapType(nubArrayType.ElementType), 0),
            NubBoolType => LLVMTypeRef.Int1,
            NubConstArrayType nubConstArrayType => LLVMTypeRef.CreatePointer(MapType(nubConstArrayType.ElementType), 0),
            NubCStringType => LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0),
            NubFloatType nubFloatType => nubFloatType.Width switch
            {
                32 => LLVMTypeRef.Float,
                64 => LLVMTypeRef.Double,
                _ => throw new ArgumentOutOfRangeException(nameof(nubFloatType.Width))
            },
            // NOTE(review): this yields an LLVM *function* type, which is what BuildCall2 needs, but a
            // function type cannot be stored, loaded or passed as a value; variables/parameters of
            // func type probably need `ptr` instead — confirm before supporting function pointers.
            NubFuncType nubFuncType => LLVMTypeRef.CreateFunction(MapType(nubFuncType.ReturnType), nubFuncType.Parameters.Select(MapType).ToArray()),
            NubIntType nubIntType => LLVMTypeRef.CreateInt((uint)nubIntType.Width),
            NubPointerType nubPointerType => LLVMTypeRef.CreatePointer(MapType(nubPointerType.BaseType), 0),
            NubSliceType nubSliceType => GetOrCreateSliceType(nubSliceType),
            NubStringType => GetOrCreateStringType(),
            NubStructType nubStructType => GetOrCreateStructType(nubStructType),
            NubVoidType => LLVMTypeRef.Void,
            _ => throw new ArgumentOutOfRangeException(nameof(nubType))
        };
    }

    /// <summary>Returns the cached named LLVM struct for <paramref name="structType"/>, creating it on first use.</summary>
    private LLVMTypeRef GetOrCreateStructType(NubStructType structType)
    {
        var key = $"struct.{structType.Module}.{structType.Name}";
        if (!_namedTypes.TryGetValue(key, out var llvmType))
        {
            llvmType = _module.Context.CreateNamedStruct(structType.Name);

            // Register the (still opaque) type before mapping the fields, so a field that refers
            // back to this struct resolves to the cached type instead of recursing forever.
            _namedTypes[key] = llvmType;

            var fields = structType.Fields.Select(f => MapType(f.Type)).ToArray();
            llvmType.StructSetBody(fields, false);
        }

        return llvmType;
    }

    /// <summary>Returns the cached slice struct <c>{ ptr data, i64 length }</c> for the element type.</summary>
    private LLVMTypeRef GetOrCreateSliceType(NubSliceType sliceType)
    {
        // NOTE(review): the cache key relies on ElementType's ToString() uniquely identifying the type — confirm.
        var key = $"slice.{sliceType.ElementType}";
        if (!_namedTypes.TryGetValue(key, out var llvmType))
        {
            llvmType = LLVMTypeRef.CreateStruct([LLVMTypeRef.CreatePointer(MapType(sliceType.ElementType), 0), LLVMTypeRef.Int64], false);
            _namedTypes[key] = llvmType;
        }

        return llvmType;
    }

    /// <summary>Returns the cached named struct <c>string { ptr, i64 }</c>.</summary>
    private LLVMTypeRef GetOrCreateStringType()
    {
        const string key = "string";
        if (!_namedTypes.TryGetValue(key, out var llvmType))
        {
            llvmType = _module.Context.CreateNamedStruct("string");
            llvmType.StructSetBody([LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0), LLVMTypeRef.Int64], false);
            _namedTypes[key] = llvmType;
        }

        return llvmType;
    }

    /// <summary>
    /// Emits every function of the compilation unit into the module and runs the LLVM verifier.
    /// </summary>
    /// <returns>The generated module. Repeated calls return the same module without regenerating it.</returns>
    public LLVMModuleRef Generate()
    {
        if (_done)
        {
            return _module;
        }

        _done = true;
        _builder = _module.Context.CreateBuilder();
        try
        {
            // Declare every prototype before emitting any body, so a call can target a function
            // that is defined later in the compilation unit.
            foreach (var funcNode in _compilationUnit.Functions)
            {
                DeclareFunction(funcNode);
            }

            foreach (var funcNode in _compilationUnit.Functions)
            {
                EmitFunction(funcNode);
            }
        }
        finally
        {
            // The builder is only needed during emission; release its native handle.
            _builder.Dispose();
        }

        // Verification failures are currently reported but not fatal.
        // TODO: throw InvalidOperationException($"Invalid LLVM module: {error}") once every node kind is implemented.
        if (!_module.TryVerify(LLVMVerifierFailureAction.LLVMPrintMessageAction, out var error))
        {
            Console.WriteLine($"Invalid LLVM module: {error}");
        }

        return _module;
    }

    /// <summary>
    /// Returns the LLVM function for <paramref name="funcNode"/>, adding a declaration to the
    /// module the first time its name is seen.
    /// </summary>
    private LLVMValueRef DeclareFunction(FuncNode funcNode)
    {
        var existing = _module.GetNamedFunction(funcNode.Name);
        if (existing.Handle != IntPtr.Zero)
        {
            return existing;
        }

        var parameterTypes = funcNode.Prototype.Parameters.Select(x => MapType(x.Type)).ToArray();
        var functionType = LLVMTypeRef.CreateFunction(MapType(funcNode.Prototype.ReturnType), parameterTypes);
        return _module.AddFunction(funcNode.Name, functionType);
    }

    /// <summary>Emits the body of <paramref name="funcNode"/>; body-less functions stay declarations.</summary>
    private void EmitFunction(FuncNode funcNode)
    {
        _currentFunction = DeclareFunction(funcNode);
        if (funcNode.Body == null)
        {
            return;
        }

        _namedValues.Clear();
        var entryBlock = _currentFunction.AppendBasicBlock("entry");
        _builder.PositionAtEnd(entryBlock);

        // Spill every parameter into a stack slot so it can be read and reassigned like a local.
        for (var i = 0; i < funcNode.Prototype.Parameters.Count; i++)
        {
            var param = funcNode.Prototype.Parameters[i];
            var llvmParam = _currentFunction.GetParam((uint)i);
            llvmParam.Name = param.Name;
            var alloca = _builder.BuildAlloca(MapType(param.Type));
            _builder.BuildStore(llvmParam, alloca);
            _namedValues[param.Name] = alloca;
        }

        EmitBlock(funcNode.Body);

        // A void function may fall off the end of its body; LLVM requires every block to end in a terminator.
        if (!IsCurrentBlockTerminated() && funcNode.Prototype.ReturnType is NubVoidType)
        {
            _builder.BuildRetVoid();
        }
    }

    /// <summary>True when the builder's current basic block already ends in a terminator (e.g. a return).</summary>
    private bool IsCurrentBlockTerminated()
    {
        return _builder.InsertBlock.Terminator.Handle != IntPtr.Zero;
    }

    /// <summary>Emits the statements of a block in order.</summary>
    private void EmitBlock(BlockNode blockNode)
    {
        foreach (var statement in blockNode.Statements)
        {
            // Anything after a terminator (e.g. code following a return) is unreachable, and
            // emitting it would place instructions after the terminator, which is invalid IR.
            if (IsCurrentBlockTerminated())
            {
                break;
            }

            EmitStatement(statement);
        }
    }

    /// <summary>Dispatches a statement node to its emitter.</summary>
    /// <exception cref="ArgumentOutOfRangeException">For unknown statement kinds.</exception>
    private void EmitStatement(StatementNode statement)
    {
        switch (statement)
        {
            case AssignmentNode assignmentNode:
                EmitAssignment(assignmentNode);
                break;
            case BlockNode blockNode:
                EmitBlock(blockNode);
                break;
            case BreakNode breakNode:
                EmitBreak(breakNode);
                break;
            case ContinueNode continueNode:
                EmitContinue(continueNode);
                break;
            case DeferNode deferNode:
                EmitDefer(deferNode);
                break;
            case IfNode ifNode:
                EmitIf(ifNode);
                break;
            case ReturnNode returnNode:
                EmitReturn(returnNode);
                break;
            case StatementFuncCallNode statementFuncCallNode:
                EmitStatementFuncCall(statementFuncCallNode);
                break;
            case VariableDeclarationNode variableDeclarationNode:
                EmitVariableDeclaration(variableDeclarationNode);
                break;
            case WhileNode whileNode:
                EmitWhile(whileNode);
                break;
            default:
                throw new ArgumentOutOfRangeException(nameof(statement));
        }
    }

    /// <summary>Stores the value of the right-hand side into the address of the target lvalue.</summary>
    private void EmitAssignment(AssignmentNode assignmentNode)
    {
        var target = EmitLValue(assignmentNode.Target);
        var value = EmitExpression(assignmentNode.Value);
        _builder.BuildStore(value, target);
    }

    private void EmitBreak(BreakNode breakNode)
    {
        throw new NotImplementedException();
    }

    private void EmitContinue(ContinueNode continueNode)
    {
        throw new NotImplementedException();
    }

    private void EmitDefer(DeferNode deferNode)
    {
        throw new NotImplementedException();
    }

    private void EmitIf(IfNode ifNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Emits <c>ret void</c> or <c>ret value</c>.</summary>
    private void EmitReturn(ReturnNode returnNode)
    {
        if (returnNode.Value == null)
        {
            _builder.BuildRetVoid();
        }
        else
        {
            var returnValue = EmitExpression(returnNode.Value);
            _builder.BuildRet(returnValue);
        }
    }

    /// <summary>Emits a call whose result is discarded.</summary>
    private void EmitStatementFuncCall(StatementFuncCallNode statementFuncCallNode)
    {
        EmitFuncCall(statementFuncCallNode.FuncCall);
    }

    /// <summary>
    /// Allocates storage for a local variable and stores its initial value, if any.
    /// </summary>
    private void EmitVariableDeclaration(VariableDeclarationNode variableDeclarationNode)
    {
        // A struct initializer already materialises a fresh stack slot, so the variable can adopt
        // it directly. Other lvalue initializers (identifiers, field/index accesses, dereferences)
        // must be copied: adopting their address would make the new variable an alias, so writes
        // to it would modify the source.
        if (variableDeclarationNode.Assignment is StructInitializerNode structInitializerNode)
        {
            _namedValues[variableDeclarationNode.Name] = EmitStructInitializer(structInitializerNode);
            return;
        }

        var allocaType = MapType(variableDeclarationNode.Type);
        var alloca = _builder.BuildAlloca(allocaType);
        if (variableDeclarationNode.Assignment != null)
        {
            var initValue = EmitExpression(variableDeclarationNode.Assignment);
            _builder.BuildStore(initValue, alloca);
        }

        _namedValues[variableDeclarationNode.Name] = alloca;
    }

    private void EmitWhile(WhileNode whileNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Emits an expression as a value: lvalues are loaded from their address, rvalues are emitted directly.</summary>
    private LLVMValueRef EmitExpression(ExpressionNode expressionNode)
    {
        switch (expressionNode)
        {
            case LValueExpressionNode lvalue:
            {
                var value = EmitLValue(lvalue);
                return _builder.BuildLoad2(MapType(lvalue.Type), value);
            }
            case RValueExpressionNode rvalue:
            {
                return EmitRValue(rvalue);
            }
            default:
            {
                throw new ArgumentOutOfRangeException(nameof(expressionNode));
            }
        }
    }

    /// <summary>Emits an lvalue expression and returns its address.</summary>
    private LLVMValueRef EmitLValue(LValueExpressionNode lValueNode)
    {
        return lValueNode switch
        {
            ArrayIndexAccessNode arrayIndexAccessNode => EmitArrayIndexAccess(arrayIndexAccessNode),
            ConstArrayIndexAccessNode constArrayIndexAccessNode => EmitConstArrayIndexAccess(constArrayIndexAccessNode),
            DereferenceNode dereferenceNode => EmitDereference(dereferenceNode),
            LValueIdentifierNode lValueIdentifierNode => EmitLValueIdentifier(lValueIdentifierNode),
            SliceIndexAccessNode sliceIndexAccessNode => EmitSliceIndexAccess(sliceIndexAccessNode),
            StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode),
            StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode),
            _ => throw new ArgumentOutOfRangeException(nameof(lValueNode))
        };
    }

    /// <summary>Emits an rvalue expression and returns its value.</summary>
    private LLVMValueRef EmitRValue(RValueExpressionNode rValueNode)
    {
        return rValueNode switch
        {
            AddressOfNode addressOfNode => EmitAddressOf(addressOfNode),
            ArrayInitializerNode arrayInitializerNode => EmitArrayInitializer(arrayInitializerNode),
            BinaryExpressionNode binaryExpressionNode => EmitBinaryExpression(binaryExpressionNode),
            BoolLiteralNode boolLiteralNode => EmitBoolLiteral(boolLiteralNode),
            ConvertFloatNode convertFloatNode => EmitConvertFloat(convertFloatNode),
            ConvertIntNode convertIntNode => EmitConvertInt(convertIntNode),
            CStringLiteralNode cStringLiteralNode => EmitCStringLiteral(cStringLiteralNode),
            Float32LiteralNode float32LiteralNode => EmitFloat32Literal(float32LiteralNode),
            Float64LiteralNode float64LiteralNode => EmitFloat64Literal(float64LiteralNode),
            FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode),
            FuncCallNode funcCallNode => EmitFuncCall(funcCallNode),
            FuncIdentifierNode funcIdentifierNode => EmitFuncIdentifier(funcIdentifierNode),
            IntLiteralNode intLiteralNode => EmitIntLiteral(intLiteralNode),
            RValueIdentifierNode rValueIdentifierNode => EmitRValueIdentifier(rValueIdentifierNode),
            SizeBuiltinNode sizeBuiltinNode => EmitSizeBuiltin(sizeBuiltinNode),
            StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode),
            UIntLiteralNode uIntLiteralNode => EmitUIntLiteral(uIntLiteralNode),
            UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(unaryExpressionNode),
            _ => throw new ArgumentOutOfRangeException(nameof(rValueNode))
        };
    }

    /// <summary>Returns the address of <c>target[index]</c> for a (pointer-represented) array.</summary>
    private LLVMValueRef EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccessNode)
    {
        var arrayType = (NubArrayType)arrayIndexAccessNode.Target.Type;
        var elementType = MapType(arrayType.ElementType);
        var target = EmitExpression(arrayIndexAccessNode.Target);
        var index = EmitExpression(arrayIndexAccessNode.Index);
        return _builder.BuildGEP2(elementType, target, [index]);
    }

    /// <summary>Stack-allocates <c>Capacity</c> elements and returns a pointer to the first one.</summary>
    private LLVMValueRef EmitArrayInitializer(ArrayInitializerNode arrayInitializerNode)
    {
        // NOTE(review): the storage lives in the current stack frame, so it must not outlive the function — confirm this matches the language's array semantics.
        var capacity = EmitExpression(arrayInitializerNode.Capacity);
        return _builder.BuildArrayAlloca(MapType(arrayInitializerNode.ElementType), capacity);
    }

    /// <summary>Emits a binary operation, choosing the instruction from the left operand's type.</summary>
    /// <exception cref="NotSupportedException">For operator/type combinations without a lowering.</exception>
    private LLVMValueRef EmitBinaryExpression(BinaryExpressionNode binaryExpressionNode)
    {
        // NOTE(review): both operands are evaluated eagerly, so LogicalAnd/LogicalOr do not
        // short-circuit; if the language requires short-circuiting this needs conditional branches.
        var left = EmitExpression(binaryExpressionNode.Left);
        var right = EmitExpression(binaryExpressionNode.Right);
        var leftType = binaryExpressionNode.Left.Type;

        if (leftType is NubIntType)
        {
            // NOTE(review): always signed division/remainder/comparison and arithmetic right shift;
            // unsigned integers (cf. UIntLiteralNode) likely need udiv/urem/ult/lshr — confirm how
            // NubIntType encodes signedness.
            return binaryExpressionNode.Operator switch
            {
                BinaryOperator.Plus => _builder.BuildAdd(left, right),
                BinaryOperator.Minus => _builder.BuildSub(left, right),
                BinaryOperator.Multiply => _builder.BuildMul(left, right),
                BinaryOperator.Divide => _builder.BuildSDiv(left, right),
                BinaryOperator.Modulo => _builder.BuildSRem(left, right),
                BinaryOperator.Equal => _builder.BuildICmp(LLVMIntPredicate.LLVMIntEQ, left, right),
                BinaryOperator.NotEqual => _builder.BuildICmp(LLVMIntPredicate.LLVMIntNE, left, right),
                BinaryOperator.LessThan => _builder.BuildICmp(LLVMIntPredicate.LLVMIntSLT, left, right),
                BinaryOperator.LessThanOrEqual => _builder.BuildICmp(LLVMIntPredicate.LLVMIntSLE, left, right),
                BinaryOperator.GreaterThan => _builder.BuildICmp(LLVMIntPredicate.LLVMIntSGT, left, right),
                BinaryOperator.GreaterThanOrEqual => _builder.BuildICmp(LLVMIntPredicate.LLVMIntSGE, left, right),
                BinaryOperator.BitwiseAnd => _builder.BuildAnd(left, right),
                BinaryOperator.BitwiseOr => _builder.BuildOr(left, right),
                BinaryOperator.BitwiseXor => _builder.BuildXor(left, right),
                BinaryOperator.LeftShift => _builder.BuildShl(left, right),
                BinaryOperator.RightShift => _builder.BuildAShr(left, right),
                _ => throw new NotSupportedException($"Binary operator {binaryExpressionNode.Operator} not supported for int")
            };
        }

        if (leftType is NubFloatType)
        {
            // Ordered predicates: any comparison involving NaN yields false.
            return binaryExpressionNode.Operator switch
            {
                BinaryOperator.Plus => _builder.BuildFAdd(left, right),
                BinaryOperator.Minus => _builder.BuildFSub(left, right),
                BinaryOperator.Multiply => _builder.BuildFMul(left, right),
                BinaryOperator.Divide => _builder.BuildFDiv(left, right),
                BinaryOperator.Modulo => _builder.BuildFRem(left, right),
                BinaryOperator.Equal => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOEQ, left, right),
                BinaryOperator.NotEqual => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealONE, left, right),
                BinaryOperator.LessThan => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOLT, left, right),
                BinaryOperator.LessThanOrEqual => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOLE, left, right),
                BinaryOperator.GreaterThan => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOGT, left, right),
                BinaryOperator.GreaterThanOrEqual => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOGE, left, right),
                _ => throw new NotSupportedException($"Binary operator {binaryExpressionNode.Operator} not supported for float")
            };
        }

        if (leftType is NubBoolType)
        {
            return binaryExpressionNode.Operator switch
            {
                BinaryOperator.LogicalAnd => _builder.BuildAnd(left, right),
                BinaryOperator.LogicalOr => _builder.BuildOr(left, right),
                BinaryOperator.Equal => _builder.BuildICmp(LLVMIntPredicate.LLVMIntEQ, left, right),
                BinaryOperator.NotEqual => _builder.BuildICmp(LLVMIntPredicate.LLVMIntNE, left, right),
                _ => throw new NotSupportedException($"Binary operator {binaryExpressionNode.Operator} not supported for bool")
            };
        }

        throw new NotSupportedException($"Binary operations for type {leftType} not supported");
    }

    /// <summary>Emits an <c>i1</c> constant.</summary>
    private LLVMValueRef EmitBoolLiteral(BoolLiteralNode boolLiteralNode)
    {
        return LLVMValueRef.CreateConstInt(LLVMTypeRef.Int1, boolLiteralNode.Value ? 1ul : 0ul);
    }

    /// <summary>Returns the address of <c>target[index]</c> for a const array.</summary>
    private LLVMValueRef EmitConstArrayIndexAccess(ConstArrayIndexAccessNode constArrayIndexAccessNode)
    {
        // Const arrays are represented as a pointer to their first element (see MapType), so the
        // element is addressed with a single-index GEP over the element type. A GEP with source
        // type `ptr` and indices [0, index] would try to index into a non-aggregate and is invalid.
        var arrayType = (NubConstArrayType)constArrayIndexAccessNode.Target.Type;
        var elementType = MapType(arrayType.ElementType);
        var target = EmitExpression(constArrayIndexAccessNode.Target);
        var index = EmitExpression(constArrayIndexAccessNode.Index);
        return _builder.BuildGEP2(elementType, target, [index]);
    }

    private LLVMValueRef EmitConvertFloat(ConvertFloatNode convertFloatNode)
    {
        throw new NotImplementedException();
    }

    private LLVMValueRef EmitConvertInt(ConvertIntNode convertIntNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Emits a private global NUL-terminated string and returns a pointer to it.</summary>
    private LLVMValueRef EmitCStringLiteral(CStringLiteralNode cStringLiteralNode)
    {
        return _builder.BuildGlobalStringPtr(cStringLiteralNode.Value, "str");
    }

    private LLVMValueRef EmitDereference(DereferenceNode dereferenceNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Emits an f32 constant.</summary>
    private LLVMValueRef EmitFloat32Literal(Float32LiteralNode float32LiteralNode)
    {
        return LLVMValueRef.CreateConstReal(LLVMTypeRef.Float, float32LiteralNode.Value);
    }

    /// <summary>Emits an f64 constant.</summary>
    private LLVMValueRef EmitFloat64Literal(Float64LiteralNode float64LiteralNode)
    {
        return LLVMValueRef.CreateConstReal(LLVMTypeRef.Double, float64LiteralNode.Value);
    }

    private LLVMValueRef EmitFloatToIntBuiltin(FloatToIntBuiltinNode floatToIntBuiltinNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Emits a call; the callee's Nub function type supplies the LLVM call signature.</summary>
    private LLVMValueRef EmitFuncCall(FuncCallNode funcCallNode)
    {
        var function = EmitExpression(funcCallNode.Expression);
        var args = funcCallNode.Parameters.Select(EmitExpression).ToArray();
        return _builder.BuildCall2(MapType(funcCallNode.Expression.Type), function, args);
    }

    /// <summary>Resolves a function name to the LLVM function declared for it.</summary>
    /// <exception cref="InvalidOperationException">If no function with that name was declared.</exception>
    private LLVMValueRef EmitFuncIdentifier(FuncIdentifierNode funcIdentifierNode)
    {
        var function = _module.GetNamedFunction(funcIdentifierNode.Name);
        if (function.Handle == IntPtr.Zero)
        {
            throw new InvalidOperationException($"Function '{funcIdentifierNode.Name}' has not been declared");
        }

        return function;
    }

    /// <summary>Emits a signed integer constant of the literal's mapped width.</summary>
    private LLVMValueRef EmitIntLiteral(IntLiteralNode intLiteralNode)
    {
        return LLVMValueRef.CreateConstInt(MapType(intLiteralNode.Type), (ulong)intLiteralNode.Value, true);
    }

    private LLVMValueRef EmitAddressOf(AddressOfNode addressOfNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Returns the storage address of a named local or parameter.</summary>
    private LLVMValueRef EmitLValueIdentifier(LValueIdentifierNode lValueIdentifierNode)
    {
        return _namedValues[lValueIdentifierNode.Name];
    }

    /// <summary>Loads the current value of a named local or parameter.</summary>
    private LLVMValueRef EmitRValueIdentifier(RValueIdentifierNode rValueIdentifierNode)
    {
        return _builder.BuildLoad2(MapType(rValueIdentifierNode.Type), _namedValues[rValueIdentifierNode.Name], rValueIdentifierNode.Name);
    }

    private LLVMValueRef EmitSizeBuiltin(SizeBuiltinNode sizeBuiltinNode)
    {
        throw new NotImplementedException();
    }

    private LLVMValueRef EmitSliceIndexAccess(SliceIndexAccessNode sliceIndexAccessNode)
    {
        throw new NotImplementedException();
    }

    private LLVMValueRef EmitStringLiteral(StringLiteralNode stringLiteralNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Returns the address of a field within a struct lvalue.</summary>
    /// <exception cref="InvalidOperationException">If the struct has no field with that name.</exception>
    private LLVMValueRef EmitStructFieldAccess(StructFieldAccessNode structFieldAccessNode)
    {
        var type = (NubStructType)structFieldAccessNode.Target.Type;
        var target = EmitLValue(structFieldAccessNode.Target);
        var fieldIndex = type.Fields.FindIndex(x => x.Name == structFieldAccessNode.Field);

        // FindIndex returns -1 when missing; cast to uint that would become an out-of-range GEP index.
        if (fieldIndex < 0)
        {
            throw new InvalidOperationException($"Struct '{type.Name}' has no field '{structFieldAccessNode.Field}'");
        }

        return _builder.BuildStructGEP2(MapType(type), target, (uint)fieldIndex);
    }

    /// <summary>
    /// Allocates a struct on the stack, stores each explicitly initialized field, and returns the
    /// struct's address. Fields without an initializer are left unset.
    /// </summary>
    /// <exception cref="InvalidOperationException">If an initializer names a field the struct does not have.</exception>
    private LLVMValueRef EmitStructInitializer(StructInitializerNode structInitializerNode)
    {
        var type = MapType(structInitializerNode.StructType);
        var ptr = _builder.BuildAlloca(type);
        foreach (var initializer in structInitializerNode.Initializers)
        {
            var value = EmitExpression(initializer.Value);
            var fieldIndex = structInitializerNode.StructType.Fields.FindIndex(x => x.Name == initializer.Key);
            if (fieldIndex < 0)
            {
                throw new InvalidOperationException($"Struct '{structInitializerNode.StructType.Name}' has no field '{initializer.Key}'");
            }

            var fieldPtr = _builder.BuildStructGEP2(type, ptr, (uint)fieldIndex);
            _builder.BuildStore(value, fieldPtr);
        }

        return ptr;
    }

    /// <summary>Emits an unsigned (zero-extended) integer constant of the literal's mapped width.</summary>
    private LLVMValueRef EmitUIntLiteral(UIntLiteralNode uIntLiteralNode)
    {
        return LLVMValueRef.CreateConstInt(MapType(uIntLiteralNode.Type), uIntLiteralNode.Value);
    }

    /// <summary>Emits negation (int or float) or bitwise/logical inversion.</summary>
    private LLVMValueRef EmitUnaryExpression(UnaryExpressionNode unaryExpressionNode)
    {
        var operand = EmitExpression(unaryExpressionNode.Operand);
        return unaryExpressionNode.Operator switch
        {
            UnaryOperator.Negate when unaryExpressionNode.Operand.Type is NubIntType => _builder.BuildNeg(operand),
            UnaryOperator.Negate when unaryExpressionNode.Operand.Type is NubFloatType => _builder.BuildFNeg(operand),
            UnaryOperator.Invert => _builder.BuildNot(operand),
            _ => throw new NotImplementedException($"Unary operator {unaryExpressionNode.Operator} not implemented")
        };
    }
}