406 lines
14 KiB
C#
406 lines
14 KiB
C#
using LLVMSharp.Interop;
|
|
using NubLang.Ast;
|
|
|
|
namespace NubLang.Generation;
|
|
|
|
/// <summary>
/// Lowers a <see cref="CompilationUnit"/> to textual LLVM IR using LLVMSharp.
/// </summary>
/// <remarks>
/// Type mapping, function prototypes and parameter spilling are implemented. Every
/// statement and expression emitter is still a stub that throws
/// <see cref="NotImplementedException"/>.
/// </remarks>
public class Generator
{
    private readonly CompilationUnit _compilationUnit;
    private readonly LLVMModuleRef _module;

    // LLVM types already built for struct, slice and string types, keyed by a string derived from the Nub type.
    private readonly Dictionary<string, LLVMTypeRef> _namedTypes = new();

    // Stack slot (alloca) of each parameter of the function currently being emitted, keyed by parameter name.
    private readonly Dictionary<string, LLVMValueRef> _namedValues = new();

    // True only once the module has been fully emitted and has passed verification.
    private bool _done;

    // True as soon as Generate() has started emitting into _module, whether or not it succeeded.
    private bool _generationAttempted;

    private LLVMBuilderRef _builder;
    private LLVMValueRef _currentFunction;

    /// <summary>
    /// Creates a generator for <paramref name="compilationUnit"/>, initialises the LLVM X86
    /// target components and creates the module that <see cref="Generate"/> emits into.
    /// </summary>
    /// <param name="compilationUnit">The compilation unit whose functions will be emitted.</param>
    public Generator(CompilationUnit compilationUnit)
    {
        _compilationUnit = compilationUnit;

        LLVM.LinkInMCJIT();
        LLVM.InitializeX86TargetInfo();
        LLVM.InitializeX86Target();
        LLVM.InitializeX86TargetMC();
        LLVM.InitializeX86AsmPrinter();

        _module = LLVMModuleRef.CreateWithName("test");
    }

    /// <summary>
    /// Translates a Nub type into the LLVM type used to represent it.
    /// </summary>
    /// <param name="nubType">The Nub type to translate.</param>
    /// <returns>The corresponding LLVM type.</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="nubType"/> is of an unhandled kind, or is a float whose width is neither 32 nor 64.
    /// </exception>
    private LLVMTypeRef MapType(NubType nubType)
    {
        return nubType switch
        {
            NubArrayType nubArrayType => LLVMTypeRef.CreatePointer(MapType(nubArrayType.ElementType), 0),
            NubBoolType => LLVMTypeRef.Int1,
            NubConstArrayType nubConstArrayType => LLVMTypeRef.CreateArray(MapType(nubConstArrayType.ElementType), (uint)nubConstArrayType.Size),
            NubCStringType => LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0),
            NubFloatType nubFloatType => nubFloatType.Width switch
            {
                32 => LLVMTypeRef.Float,
                64 => LLVMTypeRef.Double,
                _ => throw new ArgumentOutOfRangeException(nameof(nubType), nubFloatType.Width, "Unsupported float width")
            },
            // NOTE(review): this yields a bare LLVM function type rather than a pointer to one.
            // Function types are not sized, so a parameter of this type would reach BuildAlloca
            // in EmitFunction with an unsized type - confirm whether a pointer is intended here.
            NubFuncType nubFuncType => LLVMTypeRef.CreateFunction(MapType(nubFuncType.ReturnType), nubFuncType.Parameters.Select(MapType).ToArray()),
            NubIntType nubIntType => LLVMTypeRef.CreateInt((uint)nubIntType.Width),
            NubPointerType nubPointerType => LLVMTypeRef.CreatePointer(MapType(nubPointerType.BaseType), 0),
            NubSliceType nubSliceType => GetOrCreateSliceType(nubSliceType),
            NubStringType => GetOrCreateStringType(),
            NubStructType nubStructType => GetOrCreateStructType(nubStructType),
            NubVoidType => LLVMTypeRef.Void,
            _ => throw new ArgumentOutOfRangeException(nameof(nubType), nubType, "Unsupported Nub type")
        };
    }

    /// <summary>
    /// Returns the LLVM named struct for <paramref name="structType"/>, creating and caching it on first use.
    /// </summary>
    /// <param name="structType">The Nub struct type to map.</param>
    /// <returns>The cached or newly created LLVM struct type.</returns>
    private LLVMTypeRef GetOrCreateStructType(NubStructType structType)
    {
        var key = $"struct.{structType.Module}.{structType.Name}";
        if (_namedTypes.TryGetValue(key, out var llvmType))
        {
            return llvmType;
        }

        llvmType = _module.Context.CreateNamedStruct(structType.Name);

        // Cache the struct before mapping its fields so that a field referring back
        // to this struct finds the cached entry instead of recursing without end.
        _namedTypes[key] = llvmType;

        var fields = structType.Fields.Select(f => MapType(f.Type)).ToArray();
        llvmType.StructSetBody(fields, false);

        return llvmType;
    }

    /// <summary>
    /// Returns the LLVM struct used for <paramref name="sliceType"/>: an element pointer followed
    /// by a 64-bit integer. The struct is created and cached on first use.
    /// </summary>
    /// <param name="sliceType">The Nub slice type to map.</param>
    /// <returns>The cached or newly created LLVM struct type.</returns>
    private LLVMTypeRef GetOrCreateSliceType(NubSliceType sliceType)
    {
        // NOTE(review): the key is built from the element type's ToString(); this assumes
        // ToString() yields a distinct string per element type - confirm against NubType.
        var key = $"slice.{sliceType.ElementType}";
        if (_namedTypes.TryGetValue(key, out var llvmType))
        {
            return llvmType;
        }

        llvmType = LLVMTypeRef.CreateStruct([LLVMTypeRef.CreatePointer(MapType(sliceType.ElementType), 0), LLVMTypeRef.Int64], false);
        _namedTypes[key] = llvmType;

        return llvmType;
    }

    /// <summary>
    /// Returns the LLVM named struct "string": an i8 pointer followed by a 64-bit integer.
    /// The struct is created and cached on first use.
    /// </summary>
    /// <returns>The cached or newly created LLVM struct type.</returns>
    private LLVMTypeRef GetOrCreateStringType()
    {
        const string key = "string";
        if (_namedTypes.TryGetValue(key, out var llvmType))
        {
            return llvmType;
        }

        llvmType = _module.Context.CreateNamedStruct("string");
        llvmType.StructSetBody([LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0), LLVMTypeRef.Int64], false);
        _namedTypes[key] = llvmType;

        return llvmType;
    }

    /// <summary>
    /// Emits every function of the compilation unit into the module, verifies the module and
    /// returns its textual LLVM IR. After a successful run, later calls only re-print the module.
    /// </summary>
    /// <returns>The textual LLVM IR of the module.</returns>
    /// <exception cref="InvalidOperationException">
    /// The module failed verification, or an earlier call failed and left the module partially emitted.
    /// </exception>
    public string Generate()
    {
        if (_done)
        {
            return _module.PrintToString();
        }

        // A failed attempt leaves functions behind in _module. Emitting again would add them a
        // second time, and printing would expose unverified IR, so a retry is rejected instead.
        if (_generationAttempted)
        {
            throw new InvalidOperationException("A previous call to Generate failed and left the LLVM module partially emitted; create a new Generator to retry.");
        }

        _generationAttempted = true;
        _builder = _module.Context.CreateBuilder();

        try
        {
            foreach (var funcNode in _compilationUnit.Functions)
            {
                EmitFunction(funcNode);
            }
        }
        finally
        {
            // The builder is only needed while emitting; release it even when emission throws.
            _builder.Dispose();
        }

        if (!_module.TryVerify(LLVMVerifierFailureAction.LLVMPrintMessageAction, out string error))
        {
            throw new InvalidOperationException($"Invalid LLVM module: {error}");
        }

        // Mark success only now, so a failed run can never be mistaken for a finished one.
        _done = true;
        return _module.PrintToString();
    }

    /// <summary>
    /// Adds <paramref name="funcNode"/> to the module. When the node has a body, an "entry" block
    /// is created, each parameter is stored into its own alloca (recorded in
    /// <c>_namedValues</c>) and the body is emitted.
    /// </summary>
    /// <param name="funcNode">The function to emit; a node without a body gets no basic blocks.</param>
    private void EmitFunction(FuncNode funcNode)
    {
        var prototype = funcNode.Prototype;
        var parameterTypes = prototype.Parameters.Select(x => MapType(x.Type)).ToArray();
        var functionType = LLVMTypeRef.CreateFunction(MapType(prototype.ReturnType), parameterTypes);

        _currentFunction = _module.AddFunction(funcNode.Name, functionType);

        if (funcNode.Body is null)
        {
            return;
        }

        // Parameter slots recorded for the previously emitted function must not leak into this one.
        _namedValues.Clear();

        var entryBlock = _currentFunction.AppendBasicBlock("entry");
        _builder.PositionAtEnd(entryBlock);

        for (var i = 0; i < prototype.Parameters.Count; i++)
        {
            var param = prototype.Parameters[i];
            var llvmParam = _currentFunction.GetParam((uint)i);
            llvmParam.Name = param.Name;

            // Copy the incoming parameter into a stack slot and register that slot under the parameter's name.
            var alloca = _builder.BuildAlloca(MapType(param.Type), param.Name);
            _builder.BuildStore(llvmParam, alloca);
            _namedValues[param.Name] = alloca;
        }

        EmitBlock(funcNode.Body);
    }

    /// <summary>
    /// Emits the statements of <paramref name="blockNode"/> in order.
    /// </summary>
    /// <param name="blockNode">The block whose statements are emitted.</param>
    private void EmitBlock(BlockNode blockNode)
    {
        foreach (var statement in blockNode.Statements)
        {
            EmitStatement(statement);
        }
    }

    /// <summary>
    /// Dispatches <paramref name="statement"/> to the emitter for its concrete node type.
    /// </summary>
    /// <param name="statement">The statement to emit.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="statement"/> is of an unhandled node type.</exception>
    private void EmitStatement(StatementNode statement)
    {
        switch (statement)
        {
            case AssignmentNode assignmentNode:
                EmitAssignment(assignmentNode);
                break;
            case BlockNode blockNode:
                EmitBlock(blockNode);
                break;
            case BreakNode breakNode:
                EmitBreak(breakNode);
                break;
            case ContinueNode continueNode:
                EmitContinue(continueNode);
                break;
            case DeferNode deferNode:
                EmitDefer(deferNode);
                break;
            case IfNode ifNode:
                EmitIf(ifNode);
                break;
            case ReturnNode returnNode:
                EmitReturn(returnNode);
                break;
            case StatementFuncCallNode statementFuncCallNode:
                EmitStatementFuncCall(statementFuncCallNode);
                break;
            case VariableDeclarationNode variableDeclarationNode:
                EmitVariableDeclaration(variableDeclarationNode);
                break;
            case WhileNode whileNode:
                EmitWhile(whileNode);
                break;
            default:
                throw new ArgumentOutOfRangeException(nameof(statement), statement, "Unsupported statement node");
        }
    }

    // Statement emitters. None is implemented yet: each throws NotImplementedException,
    // so reaching an unsupported statement fails loudly instead of emitting nothing.

    private void EmitAssignment(AssignmentNode assignmentNode) => throw new NotImplementedException();

    private void EmitBreak(BreakNode breakNode) => throw new NotImplementedException();

    private void EmitContinue(ContinueNode continueNode) => throw new NotImplementedException();

    private void EmitDefer(DeferNode deferNode) => throw new NotImplementedException();

    private void EmitIf(IfNode ifNode) => throw new NotImplementedException();

    private void EmitReturn(ReturnNode returnNode) => throw new NotImplementedException();

    private void EmitStatementFuncCall(StatementFuncCallNode statementFuncCallNode) => throw new NotImplementedException();

    private void EmitVariableDeclaration(VariableDeclarationNode variableDeclarationNode) => throw new NotImplementedException();

    private void EmitWhile(WhileNode whileNode) => throw new NotImplementedException();

    /// <summary>
    /// Dispatches <paramref name="expressionNode"/> to the emitter for its concrete node type.
    /// </summary>
    /// <param name="expressionNode">The expression to emit.</param>
    /// <returns>The LLVM value produced by the selected emitter.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="expressionNode"/> is of an unhandled node type.</exception>
    private LLVMValueRef EmitExpression(ExpressionNode expressionNode)
    {
        return expressionNode switch
        {
            ArrayIndexAccessNode arrayIndexAccessNode => EmitArrayIndexAccess(arrayIndexAccessNode),
            ArrayInitializerNode arrayInitializerNode => EmitArrayInitializer(arrayInitializerNode),
            BinaryExpressionNode binaryExpressionNode => EmitBinaryExpression(binaryExpressionNode),
            BoolLiteralNode boolLiteralNode => EmitBoolLiteral(boolLiteralNode),
            ConstArrayIndexAccessNode constArrayIndexAccessNode => EmitConstArrayIndexAccess(constArrayIndexAccessNode),
            ConstArrayInitializerNode constArrayInitializerNode => EmitConstArrayInitializer(constArrayInitializerNode),
            ConvertFloatNode convertFloatNode => EmitConvertFloat(convertFloatNode),
            ConvertIntNode convertIntNode => EmitConvertInt(convertIntNode),
            CStringLiteralNode cStringLiteralNode => EmitCStringLiteral(cStringLiteralNode),
            DereferenceNode dereferenceNode => EmitDereference(dereferenceNode),
            Float32LiteralNode float32LiteralNode => EmitFloat32Literal(float32LiteralNode),
            Float64LiteralNode float64LiteralNode => EmitFloat64Literal(float64LiteralNode),
            FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode),
            FuncCallNode funcCallNode => EmitFuncCall(funcCallNode),
            FuncIdentifierNode funcIdentifierNode => EmitFuncIdentifier(funcIdentifierNode),
            IntLiteralNode intLiteralNode => EmitIntLiteral(intLiteralNode),
            AddressOfNode addressOfNode => EmitAddressOf(addressOfNode),
            LValueIdentifierNode lValueIdentifierNode => EmitLValueIdentifier(lValueIdentifierNode),
            RValueIdentifierNode rValueIdentifierNode => EmitRValueIdentifier(rValueIdentifierNode),
            SizeBuiltinNode sizeBuiltinNode => EmitSizeBuiltin(sizeBuiltinNode),
            SliceIndexAccessNode sliceIndexAccessNode => EmitSliceIndexAccess(sliceIndexAccessNode),
            StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode),
            StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode),
            StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode),
            UIntLiteralNode uIntLiteralNode => EmitUIntLiteral(uIntLiteralNode),
            UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(unaryExpressionNode),
            _ => throw new ArgumentOutOfRangeException(nameof(expressionNode), expressionNode, "Unsupported expression node")
        };
    }

    // Expression emitters. None is implemented yet: each throws NotImplementedException.

    private LLVMValueRef EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccessNode) => throw new NotImplementedException();

    private LLVMValueRef EmitArrayInitializer(ArrayInitializerNode arrayInitializerNode) => throw new NotImplementedException();

    private LLVMValueRef EmitBinaryExpression(BinaryExpressionNode binaryExpressionNode) => throw new NotImplementedException();

    private LLVMValueRef EmitBoolLiteral(BoolLiteralNode boolLiteralNode) => throw new NotImplementedException();

    private LLVMValueRef EmitConstArrayIndexAccess(ConstArrayIndexAccessNode constArrayIndexAccessNode) => throw new NotImplementedException();

    private LLVMValueRef EmitConstArrayInitializer(ConstArrayInitializerNode constArrayInitializerNode) => throw new NotImplementedException();

    private LLVMValueRef EmitConvertFloat(ConvertFloatNode convertFloatNode) => throw new NotImplementedException();

    private LLVMValueRef EmitConvertInt(ConvertIntNode convertIntNode) => throw new NotImplementedException();

    private LLVMValueRef EmitCStringLiteral(CStringLiteralNode cStringLiteralNode) => throw new NotImplementedException();

    private LLVMValueRef EmitDereference(DereferenceNode dereferenceNode) => throw new NotImplementedException();

    private LLVMValueRef EmitFloat32Literal(Float32LiteralNode float32LiteralNode) => throw new NotImplementedException();

    private LLVMValueRef EmitFloat64Literal(Float64LiteralNode float64LiteralNode) => throw new NotImplementedException();

    private LLVMValueRef EmitFloatToIntBuiltin(FloatToIntBuiltinNode floatToIntBuiltinNode) => throw new NotImplementedException();

    private LLVMValueRef EmitFuncCall(FuncCallNode funcCallNode) => throw new NotImplementedException();

    private LLVMValueRef EmitFuncIdentifier(FuncIdentifierNode funcIdentifierNode) => throw new NotImplementedException();

    private LLVMValueRef EmitIntLiteral(IntLiteralNode intLiteralNode) => throw new NotImplementedException();

    private LLVMValueRef EmitAddressOf(AddressOfNode addressOfNode) => throw new NotImplementedException();

    private LLVMValueRef EmitLValueIdentifier(LValueIdentifierNode lValueIdentifierNode) => throw new NotImplementedException();

    private LLVMValueRef EmitRValueIdentifier(RValueIdentifierNode rValueIdentifierNode) => throw new NotImplementedException();

    private LLVMValueRef EmitSizeBuiltin(SizeBuiltinNode sizeBuiltinNode) => throw new NotImplementedException();

    private LLVMValueRef EmitSliceIndexAccess(SliceIndexAccessNode sliceIndexAccessNode) => throw new NotImplementedException();

    private LLVMValueRef EmitStringLiteral(StringLiteralNode stringLiteralNode) => throw new NotImplementedException();

    private LLVMValueRef EmitStructFieldAccess(StructFieldAccessNode structFieldAccessNode) => throw new NotImplementedException();

    private LLVMValueRef EmitStructInitializer(StructInitializerNode structInitializerNode) => throw new NotImplementedException();

    private LLVMValueRef EmitUIntLiteral(UIntLiteralNode uIntLiteralNode) => throw new NotImplementedException();

    private LLVMValueRef EmitUnaryExpression(UnaryExpressionNode unaryExpressionNode) => throw new NotImplementedException();
}