About this Workshop

In this workshop you will learn a bit about what LLVM and MLIR are by implementing a compiler for a really simple programming language. This language only has variables of type i64, and only supports simple functions with arguments, basic arithmetic operations and if/else statements.

What is LLVM?

To understand MLIR, you first need to know what LLVM is.

LLVM is what one would call a compiler backend. It is made of many reusable components and deals with the intricacies of each CPU architecture, providing a higher-level API for programmers to work with. It is kind of like a "game engine", but for making compiled (or JIT-compiled) programming languages, so you don't have to deal with the lowest-level details (lower than what some people call low-level, anyway).

To abstract away all the CPU-specific details, LLVM acts as a "target" with a high-level pseudo-assembly, which we call the LLVM IR (IR means Intermediate Representation). We can consider LLVM an abstract machine (much like C): it has infinite registers, among other details, and to program this machine we use LLVM IR.

An example of this IR (snippet from A Gentle Introduction to LLVM IR):

define i32 @pow(i32 %x, i32 %y) {
  ; Create slots for r and the index, and initialize them.
  ; This is equivalent to something like
  ;   int i = 0, r = 1;
  ; in C.
  %r = alloca i32
  %i = alloca i32
  store i32 1, ptr %r
  store i32 0, ptr %i
  br label %loop_start

loop_start:
  ; Load the index and check if it equals y.
  %i.check = load i32, ptr %i
  %done = icmp eq i32 %i.check, %y
  br i1 %done, label %exit, label %loop

loop:
  ; r *= x
  %r.old = load i32, ptr %r
  %r.new = mul i32 %r.old, %x
  store i32 %r.new, ptr %r

  ; i += 1
  %i.old = load i32, ptr %i
  %i.new = add i32 %i.old, 1
  store i32 %i.new, ptr %i

  br label %loop_start

exit:
  %r.ret = load i32, ptr %r
  ret i32 %r.ret
}
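If LLVM IR is new to you, it may help to see what @pow computes in Rust. The following is a hand-written sketch, not generated code: the two allocas become mutable locals, and each load/store becomes a read or write of them. It uses wrapping_mul and wrapping_add because LLVM's mul and add without the nsw/nuw flags wrap on overflow, while Rust's * and + panic on overflow in debug builds.

```rust
/// Rough Rust equivalent of the `@pow` LLVM IR function above.
/// `r` and `i` play the role of the two `alloca` slots.
pub fn pow(x: i32, y: i32) -> i32 {
    let mut r: i32 = 1; // store i32 1, ptr %r
    let mut i: i32 = 0; // store i32 0, ptr %i

    // loop_start: load the index and leave the loop once it equals y.
    while i != y {
        // loop: r *= x (LLVM `mul` without nsw/nuw wraps on overflow)
        r = r.wrapping_mul(x);
        // i += 1 (LLVM `add` wraps as well)
        i = i.wrapping_add(1);
    }

    // exit: load r and return it.
    r
}
```

For example, pow(2, 10) returns 1024 and pow(3, 0) returns 1.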

To model the specific attributes of the CPU we want to target, we define a data layout and specify a target triple, such as x86_64-apple-macosx10.7.0.

The data layout specifies how data is laid out in memory, such as the alignment and size of different data types, whether it's big or little endian, etc.
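For the machine Rust itself is compiled for, the standard library can show the kind of facts a data layout records. This sketch only inspects the host target; it does not parse an LLVM data layout string. Sizes of fixed-width integers never change, but alignment and pointer width are target-dependent (for example, 32-bit x86 Linux aligns u64 to 4 bytes rather than 8).

```rust
use std::mem::{align_of, size_of};

/// The bytes of `v` in little-endian and big-endian order.
pub fn byte_orders(v: u32) -> ([u8; 4], [u8; 4]) {
    (v.to_le_bytes(), v.to_be_bytes())
}

/// (name, size, alignment) in bytes for a few types on the host target.
pub fn host_layout() -> Vec<(&'static str, usize, usize)> {
    vec![
        ("u8", size_of::<u8>(), align_of::<u8>()),
        ("u32", size_of::<u32>(), align_of::<u32>()),
        ("u64", size_of::<u64>(), align_of::<u64>()),
        ("usize", size_of::<usize>(), align_of::<usize>()),
    ]
}

/// True if the host stores the least significant byte first.
pub fn host_is_little_endian() -> bool {
    cfg!(target_endian = "little")
}
```

For example, byte_orders(0x0102_0304) returns ([4, 3, 2, 1], [1, 2, 3, 4]): the same value, laid out in opposite byte orders.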

The LLVM IR uses the following structure:

A module is a compilation unit, and within a module there are functions, globals and more.

A function is made up of blocks. Each block contains a sequence of instructions that always run sequentially; there is no control flow within a single block, so to implement control flow you must jump to other blocks.

It uses Static Single Assignment form (SSA form), which means each variable is assigned once and only once; this allows LLVM to do a lot of optimizations.
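To make "assigned once and only once" concrete, here is a toy renamer that turns straight-line assignments into SSA form by giving every assignment a fresh numbered name. It is only a sketch: expressions are whitespace-separated strings, and branches, which need the PHI nodes described in the next paragraph, are not handled.

```rust
use std::collections::HashMap;

/// Renames a straight-line sequence of `(destination, expression)` pairs
/// into SSA form: every assignment gets a fresh name (`x0`, `x1`, ...)
/// and every use refers to the most recent definition.
pub fn to_ssa(stmts: &[(&str, &str)]) -> Vec<String> {
    // Next version number to hand out for each variable.
    let mut next: HashMap<String, usize> = HashMap::new();
    // Current SSA name for each variable.
    let mut current: HashMap<String, String> = HashMap::new();
    let mut out = Vec::new();

    for (dest, expr) in stmts {
        // Rewrite the uses first: they refer to the previous definition.
        let rhs: Vec<String> = expr
            .split_whitespace()
            .map(|tok| current.get(tok).cloned().unwrap_or_else(|| tok.to_string()))
            .collect();

        // Then create a brand-new name for the destination.
        let n = next.entry(dest.to_string()).or_insert(0);
        let version = *n;
        *n += 1;
        let name = format!("{}{}", dest, version);
        current.insert(dest.to_string(), name.clone());

        out.push(format!("{} = {}", name, rhs.join(" ")));
    }
    out
}
```

For example, the assignments x = 1; x = x + 2; y = x * x become x0 = 1, x1 = x0 + 2 and y0 = x1 * x1.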

Because every value is assigned only once, care is needed when control flow is involved: every definition must dominate all of its uses. For example, a target block may have 2 predecessors (2 blocks that each end by jumping to this target block), and each predecessor may define its own version of a value the target block uses. To "unify" those values in the target block you must use the phi instruction, which defines a PHI node that picks the value coming from whichever predecessor actually jumped there.

One can avoid using PHI nodes by relying on allocas. An alloca is a reservation of stack space: basically, you give it a type (and optionally an element count and an alignment) and it gives you a pointer to this allocation. You can then simply load from and store to that pointer from any branch, and you don't have to deal with PHI nodes. This is what most language frontends do, Rust included: they rely on LLVM to later promote the allocas into registers whenever possible.

What is MLIR?

MLIR is kind of an IR of IRs, and it supports many of them using "dialects". For example, you may have heard of NVVM IR (CUDA); MLIR supports modeling it through the NVVM dialect (or ROCDL for AMD), but there is also a more generic and higher-level GPU dialect.

Within MLIR, there is a dialect to model the LLVM IR itself, and also conversions and transformations from other dialects into the LLVM IR dialect.

With this, one can create a high-level dialect that can be converted into, for example, a GPU kernel or a CPU program; this is roughly what the TOSA dialect does.

As the MLIR main page says, "MLIR aims to address software fragmentation", and it achieves that by defining multiple dialects and conversions between them.

Some notable dialects:

  • Builtin: The builtin dialect contains a core set of Attributes, Operations, and Types that have wide applicability across a very large number of domains and abstractions. Many of the components of this dialect are also instrumental in the implementation of the core IR.
  • Affine: This dialect provides a powerful abstraction for affine operations and analyses.
  • Async: This dialect contains operations for modeling asynchronous execution.
  • SCF: The scf (structured control flow) dialect contains operations that represent control flow constructs such as if and for. Being structured means that the control flow has a structure unlike, for example, gotos or asserts.
  • CF: This dialect contains low-level, i.e. non-region based, control flow constructs. These constructs generally represent control flow directly on SSA blocks of a control flow graph.
  • LLVM: This dialect maps LLVM IR into MLIR by defining the corresponding operations and types. LLVM IR metadata is usually represented as MLIR attributes, which offer additional structure verification.
  • GPU: This dialect provides middle-level abstractions for launching GPU kernels following a programming model similar to that of CUDA or OpenCL.
  • Arith: The arith dialect is intended to hold basic integer and floating point mathematical operations. This includes unary, binary, and ternary arithmetic ops, bitwise and shift ops, cast ops, and compare ops. Operations in this dialect also accept vectors and tensors of integers or floats.
  • TOSA: TOSA was developed after parallel efforts to rationalize the top-down picture from multiple high-level frameworks, as well as a bottom-up view of different hardware target concerns (CPU, GPU and NPU), and reflects a set of choices that attempt to manage both sets of requirements.
  • Func: This dialect contains operations surrounding high order function abstractions, such as calls.

The structure of the MLIR IR is the following:

A module defines a compilation unit and contains one or more operations (the module itself is an operation, builtin.module). An operation holds zero or more regions, a region holds zero or more blocks, and a block holds zero or more operations.

With this recursive structure, MLIR can represent the logic of all these IRs.
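The nesting described above can be sketched with plain Rust types. This is a toy model for intuition only, not how MLIR or melior store operations (real operations also carry operands, results, attributes and successors); the point is that one recursive walk reaches every operation, however deeply it is nested in regions and blocks.

```rust
/// Toy model of MLIR's recursive structure (not melior's API).
#[derive(Debug)]
pub struct Operation {
    pub name: String,
    pub regions: Vec<Region>,
}

#[derive(Debug)]
pub struct Region {
    pub blocks: Vec<Block>,
}

#[derive(Debug)]
pub struct Block {
    pub operations: Vec<Operation>,
}

impl Operation {
    pub fn new(name: &str, regions: Vec<Region>) -> Self {
        Operation { name: name.to_string(), regions }
    }

    /// Counts this operation plus every operation nested inside its regions.
    pub fn count_ops(&self) -> usize {
        1 + self
            .regions
            .iter()
            .flat_map(|r| r.blocks.iter())
            .flat_map(|b| b.operations.iter())
            .map(Operation::count_ops)
            .sum::<usize>()
    }

    /// Collects operation names in pre-order, like a walk over the IR.
    pub fn walk(&self, out: &mut Vec<String>) {
        out.push(self.name.clone());
        for region in &self.regions {
            for block in &region.blocks {
                for op in &block.operations {
                    op.walk(out);
                }
            }
        }
    }
}

/// Helper: a region with a single block holding the given operations.
pub fn single_block(ops: Vec<Operation>) -> Region {
    Region { blocks: vec![Block { operations: ops }] }
}
```

For instance, a builtin.module containing one func.func whose body holds two operations counts as four operations in total.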

Example MLIR code, using multiple dialects:

module {
  func.func @foo() {
    %c0 = arith.constant 0 : index
    %0 = scf.while (%arg0 = %c0) : (index) -> f64 {
      %false = arith.constant false
      %cst = arith.constant 4.200000e+01 : f64
      scf.condition(%false) %cst : f64
    } do {
    ^bb0(%arg0: f64):
      %c42 = arith.constant 42 : index
      scf.yield %c42 : index
    }
    return
  }
}

In MLIR, blocks can have arguments; this is MLIR's solution to PHI nodes. If a target block uses a value coming from multiple independent branches, add it as a block argument, and each predecessor must pass its own value in its respective branch operation.

You can see in the MLIR example above that there is a while loop. This is thanks to the SCF dialect, which provides high-level control flow operations. If your target is LLVM, these operations are later converted into plain blocks and branches (first in the cf dialect, then in the LLVM dialect).
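After SCF is lowered, loop-carried values travel between blocks as block arguments on branches. The sketch below models this in plain Rust for the pow function from the LLVM section: each block is a match arm that receives its arguments as a slice, and each jump passes the new values explicitly, much like cf.br ^bb1(%r.new, %i.new : i64, i64) would. It illustrates the idea only; it is neither MLIR nor melior code, and the block numbering is made up.

```rust
/// What a block's terminator decides: jump to another block, passing
/// values as block arguments, or return from the function.
enum Next {
    Jump(usize, Vec<i64>),
    Return(i64),
}

/// pow(x, y) as a tiny control flow graph. Block ids (made up for this
/// sketch): 0 = entry, 1 = header(r, i), 2 = body(r, i).
pub fn pow_cfg(x: i64, y: i64) -> i64 {
    // Runs one block with its arguments and returns its terminator.
    let run_block = |id: usize, args: &[i64]| -> Next {
        match id {
            // entry: branch to the header with r = 1, i = 0.
            0 => Next::Jump(1, vec![1, 0]),
            // header(r, i): exit when i == y, otherwise forward r and i.
            1 => {
                let (r, i) = (args[0], args[1]);
                if i == y {
                    Next::Return(r)
                } else {
                    Next::Jump(2, vec![r, i])
                }
            }
            // body(r, i): compute new values and pass them back to the header.
            2 => {
                let (r, i) = (args[0], args[1]);
                Next::Jump(1, vec![r.wrapping_mul(x), i.wrapping_add(1)])
            }
            _ => unreachable!("unknown block {}", id),
        }
    };

    // Driver: start at the entry block and follow jumps until a block returns.
    let mut current = 0usize;
    let mut args: Vec<i64> = Vec::new();
    loop {
        match run_block(current, &args[..]) {
            Next::Jump(target, new_args) => {
                current = target;
                args = new_args;
            }
            Next::Return(value) => return value,
        }
    }
}
```

Note there is no mutable state shared between blocks: each value the header needs arrives as an argument, which is exactly what block arguments replace PHI nodes with.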

In our case, we want a compiled program, so LLVM IR will be our target. This means we have to add passes that convert the multiple dialects we use into the LLVM dialect, then translate the MLIR into LLVM IR and compile it. This is done either programmatically or with the mlir-opt and mlir-translate tools.

MLIR and melior Basics

To use MLIR with Rust, the following library is used: https://github.com/mlir-rs/melior

This page briefly explains how to use it.

The Context

#![allow(unused)]
fn main() {
 let context = Context::new();
}

The context is an opaque struct that holds all the created attributes, locations and more. It must be passed to nearly all melior methods.

Location

#![allow(unused)]
fn main() {
// A location pointing to a file line col
let loc: Location<'c> = Location::new(&context, filename, line, column);
// An unknown location.
let loc = Location::unknown(&context);
}

All operations and block arguments have a location in MLIR. If there is no real location, you can use Location::unknown.

Module

The module is a compilation unit. More specifically, a module is a builtin.module operation, which internally holds a single region with a single block.

#![allow(unused)]
fn main() {
let module: Module<'c> = Module::new(Location::unknown(&context));
}

To add an operation to a module, you can do the following:

#![allow(unused)]
fn main() {
// body() returns a BlockRef, since operations can only be added to blocks.
module.body().append_operation(operation);
}

Operation

An operation is an instruction. It can hold regions, which themselves hold blocks. It also has attributes, operands, results and successors.

  • Attributes are like configuration parameters for the operation.
  • Operands are the input values.
  • Results are the values the operation produces; there can be zero or more.
  • Successors are blocks to branch into.

Types

Each dialect can define its own types. For example, the llvm dialect defines the !llvm.ptr pointer type. A type you will use often is index, which melior can create directly:

#![allow(unused)]
fn main() {
let idx = Type::index(&context);
}

The builtin dialect defines some common types, index among them. They can be created with Type::<name> or with dedicated structs, such as IntegerType:

#![allow(unused)]
fn main() {
let my_f16 = Type::float16(&context);
// A signless 64-bit integer (i64).
let my_i64: Type<'c> = IntegerType::new(&context, 64).into();
}

Attributes

Most operations accept or require attributes. For example, the func.func operation requires a StringAttribute to define the function name, and other operations may require a TypeAttribute to pass type information.

#![allow(unused)]
fn main() {
let my_type_attr: Attribute<'c> =
    TypeAttribute::new(IntegerType::new(&context, 64).into()).into();
}

In melior there are 4 ways to create an operation: using ODS, using a function from melior's dialect modules, using the operation builder, or using the helper traits implemented on Block.

ODS

The ODS bindings are generated by melior's Rust macros from MLIR's TableGen operation definitions (ODS stands for Operation Definition Specification).

With ODS:

#![allow(unused)]
fn main() {
use melior::dialect::ods;

let my_alloca = block.append_operation(
    ods::llvm::alloca(&context, res, array_size,
                      elem_type, location).into()
);
// Get the returned ptr
let ptr: Value<'c> = my_alloca.result(0).unwrap().into();
}

The dialect module

This is a handcrafted API, so it may miss a lot of operations:

#![allow(unused)]
fn main() {
let my_alloca = block.append_operation(
        melior::dialect::llvm::alloca(&context, array_size, ptr_type,
                                      location, extra_options)
    );
// Get the returned ptr
let ptr: Value<'c> = my_alloca.result(0).unwrap().into();
}

The operation builder

#![allow(unused)]
fn main() {
let location = Location::unknown(&context);
let r#type = Type::index(&context);
let block = Block::new(&[(r#type, location)]);
let argument: Value = block.argument(0).unwrap().into();

let operands = vec![argument, argument, argument];
let operation = OperationBuilder::new("foo", Location::unknown(&context))
    .add_operands(&operands)
    .build()
    .unwrap();
}

Helper Traits

Some frequently used operations, mainly those in the llvm, arith and builtin dialects, have helper traits in melior that make them less verbose. These traits are implemented on Block, so you can simply call block.load(..).

Region

A region holds zero or more blocks, and how many regions an operation has (zero or more) depends on the operation.

Regions are mostly used by higher-level dialects like SCF, whose while and for constructs keep their bodies in regions. The CF dialect instead works with blocks and branches between them.

A region is more isolated than a block. Within a region, you can easily use a value defined in a block that dominates the current one (for example, its single predecessor), but taking a value from another region that is not an ancestor requires passing it as an operand or block argument. This makes region-based operations like those in SCF a bit harder to work with in some contexts.

#![allow(unused)]
fn main() {
let region = Region::new();

// Add a block to the region.
let block_ref = region.append_block(Block::new(&[]));

// Here one would implement the function body

// pass the region to a operation.
let func_op = func::func(&context, name, r#type, region, attributes, location);
}

Block

A block holds a sequence of operations that run one after another. Operations may contain control flow inside their own regions (like scf.if), but control always returns to the next operation in the block; only the block's terminator can transfer control elsewhere.

In most regions, a block must end with a terminator, that is, an operation that has the Terminator trait. These are usually operations that branch, like cf.br, return, like func.return, or diverge, like llvm.unreachable.

#![allow(unused)]
fn main() {
// To create a block we must pass the arguments it accepts, as a slice of (Type, Location) tuples.
let block = Block::new(&[
    (Type::float32(&context), Location::unknown(&context))
]);

// Get the first argument to use it in future operations:
let arg1: Value = block.argument(0).unwrap().into();

block.append_operation(my_op_here);

}

Example function adding 2 arguments

Here you can view how to create a function that accepts 2 arguments:

#![allow(unused)]
fn main() {
use melior::{
    Context,
    dialect::{arith, DialectRegistry, func},
    ir::{*, attribute::{StringAttribute, TypeAttribute}, r#type::FunctionType},
    utility::register_all_dialects,
};

// We need a registry to hold all the dialects
let registry = DialectRegistry::new();
// Register all dialects that come with MLIR.
register_all_dialects(&registry);

// The MLIR context, like the LLVM one.
let context = Context::new();
context.append_dialect_registry(&registry);
context.load_all_available_dialects();

// A location is a debug location like in LLVM. In MLIR all
// operations need a location, even if it's "unknown".
let location = Location::unknown(&context);

// An MLIR module is akin to an LLVM module.
let module = Module::new(location);

// An integer-like type with platform-dependent bit width (like size_t or usize).
// This is a type defined in the Builtin dialect.
let index_type = Type::index(&context);

// Append a `func::func` operation to the body (a block) of the module.
// This operation accepts a string attribute, which is the name.
// A type attribute, which contains a function type in this case.
// Then it accepts a single region, which is where the body
// of the function will be, this region can have
// multiple blocks, which is how you may implement
// control flow within the function.
// These blocks each can have more operations.
module.body().append_operation(func::func(
    &context,
    // accepts a StringAttribute which is the function name.
    StringAttribute::new(&context, "add"),
    // A type attribute, defining the function signature.
    TypeAttribute::new(
            FunctionType::new(&context, &[index_type, index_type], &[index_type]).into()
        ),
    {
        // The first block within the region, blocks accept arguments
        // In regions with control flow, MLIR leverages
        // this structure to implicitly represent
        // the passage of control-flow dependent values without the complex nuances
        // of PHI nodes in traditional SSA representations.
        let block = Block::new(&[(index_type, location), (index_type, location)]);

        // Use the arith dialect to add the 2 arguments.
        let sum = block.append_operation(arith::addi(
            block.argument(0).unwrap().into(),
            block.argument(1).unwrap().into(),
            location
        ));

        // Return the result using the "func" dialect return operation.
        block.append_operation(
            func::r#return( &[sum.result(0).unwrap().into()], location)
        );

        // The Func operation requires a region,
        // we add the block we created to the region and return it,
        // which is passed as an argument to the `func::func` function.
        let region = Region::new();
        region.append_block(block);
        region
    },
    &[],
    location,
));

assert!(module.as_operation().verify());
}

Workshop: Setup

Project Setup

Easy way

git clone https://github.com/lambdaclass/mlir-workshop
cd mlir-workshop
make deps
source env.sh
make build

Dependencies (manual way)

  • Rust
  • LLVM and MLIR

You can install LLVM and MLIR through brew (this workshop uses LLVM/MLIR 19):

brew install llvm@19
git clone https://github.com/lambdaclass/mlir-workshop
cd mlir-workshop

For melior to find the libraries, we need to set up some environment variables (tip: you can add them to .zshenv):

export MLIR_SYS_190_PREFIX="$(brew --prefix llvm@19)"
export LLVM_SYS_191_PREFIX="$(brew --prefix llvm@19)"
export TABLEGEN_190_PREFIX="$(brew --prefix llvm@19)"

Verify you can build the project:

cargo build

Workshop: Walkthrough of the prepared codebase

.
├── ast.rs // The language Abstract syntax tree.
├── codegen
│   ├── expressions.rs
│   ├── ifelse_stmt.rs
│   ├── let_stmt.rs
│   └── return_stmt.rs
├── codegen.rs // Glue code for the codegen methods.
├── grammar.lalrpop // LALRPOP grammar for parsing
├── main.rs // CLI and MLIR Context creation
└── util.rs // Code to translate MLIR to LLVM and link the binary

The workshop project already contains the code to handle the following:

  • Lexer and parser
  • CLI
  • The language AST
  • Translating to LLVM IR and linking the binary.

Thus, what's missing is implementing the methods that "compile" the code, a.k.a. emit the MLIR operations. They are located under the codegen/ folder.

The AST

The language AST is quite simple; it consists of the following:

#![allow(unused)]
fn main() {
/// The possible expressions, usually on the right hand side of an assignment
/// let x = <expr> ;
#[derive(Debug, Clone)]
pub enum Expr {
    Number(i64),
    Call { target: String, args: Vec<Expr> },
    Variable(String),
    Op(Box<Expr>, Opcode, Box<Expr>),
}

#[derive(Debug, Clone)]
pub enum Opcode {
    Mul,
    Div,
    Add,
    Sub,
    Eq,
    Neq,
}

// A statement, separated by ;
#[derive(Debug, Clone)]
pub enum Statement {
    Let(LetStmt),
    If(IfStmt),
    Return(ReturnStmt),
}

/// The let statement, it binds a value from an expression to the given variable.
#[derive(Debug, Clone)]
pub struct LetStmt {
    pub variable: String,
    pub expr: Expr,
}

/// An if with an optional else statement, depending on whether the condition evaluates to true,
/// take one or another block.
#[derive(Debug, Clone)]
pub struct IfStmt {
    pub cond: Expr,
    pub then: Block,
    pub r#else: Option<Block>,
}

/// The return statement of a function
#[derive(Debug, Clone)]
pub struct ReturnStmt {
    pub expr: Expr,
}

/// A block is a series of statements, used as the function body and if else blocks.
#[derive(Debug, Clone)]
pub struct Block {
    pub stmts: Vec<Statement>,
}

/// Describes a function, with the arguments.
/// Note: in this simple language functions always return an i64.
#[derive(Debug, Clone)]
pub struct Function {
    pub name: String,
    pub args: Vec<String>,
    pub body: Block,
}

/// The whole program, simply a list of functions.
/// The function named "main" will be the entrypoint.
#[derive(Debug, Clone)]
pub struct Program {
    pub functions: Vec<Function>,
}

}

Workshop: Compiling Expressions

To compile expressions, the following is needed:

  • Create a constant number.
  • From a variable identifier, get its value.
  • Apply a binary operation to 2 other expressions.

#![allow(unused)]
fn main() {
// src/codegen/expressions.rs
pub fn compile_expr<'ctx: 'parent, 'parent>(
    // Helper struct with the MLIR Context and Module
    ctx: &ModuleCtx<'ctx>,
    // Hashmap storing the local variables
    locals: &HashMap<String, Value<'ctx, 'parent>>,
    // The current block to work on.
    block: &'parent Block<'ctx>,
    // The expression to compile.
    expr: &Expr,
) -> Value<'ctx, 'parent> {
    match expr {
        Expr::Number(_value) => {
            todo!("implement constant numbers")
        }
        Expr::Variable(name) => {
            todo!("implement loading values from the given variable name")
        }
        Expr::Op(lhs_expr, opcode, rhs_expr) => match opcode {
            Opcode::Mul => todo!("implement mul"),
            Opcode::Div => todo!("implement div"),
            Opcode::Add => todo!("implement add"),
            Opcode::Sub => todo!("implement sub"),
            Opcode::Eq => todo!("implement eq"),
            Opcode::Neq => todo!("implement neq"),
        },
        Expr::Call { target, args } => todo!("implement function call"),
    }
}
}

Constants in MLIR

There are various ways to create a constant; in our case, we have 2 dialects available to use: arith (arith.constant) and llvm (llvm.mlir.constant).

You can find documentation about all dialects and their operations here: https://mlir.llvm.org/docs/Dialects/

It is recommended to use the arith dialect in this case.

Some useful types you will need: Type, IntegerAttribute, IntegerType, Location.

Types like IntegerType have an into() method to turn them into Type.

Loading a variable value

To make things simpler, all variables are stored in memory reserved with llvm.alloca, an operation that reserves stack space and returns a pointer to it. Thus, depending on the use, a load or store operation is needed. This avoids dealing with block arguments, but makes the compiler rely on LLVM to optimize these allocas away (which it does really well).

In this case, you can use the llvm dialect to load from the pointer. Each variable's pointer can be found in the given locals hashmap.

Binary operations

To iterate is human, to recurse, divine

Here you will need to recursively compile the lhs and rhs expressions, then use the arith dialect to compute the binary operation on the two resulting values (for Eq and Neq, see arith.cmpi).
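compile_expr follows the same recursion as a plain interpreter. As a warm-up, here is a hypothetical evaluator over a trimmed copy of the AST (Call left out) that computes i64 values directly; in the real codegen, each arm would instead append an arith or llvm operation to block and return its result Value. Returning 1 or 0 for comparisons is this sketch's own choice: in MLIR, arith.cmpi produces an i1.

```rust
use std::collections::HashMap;

// Trimmed copy of the workshop's `Expr` and `Opcode` (no function calls).
#[derive(Debug, Clone)]
pub enum Expr {
    Number(i64),
    Variable(String),
    Op(Box<Expr>, Opcode, Box<Expr>),
}

#[derive(Debug, Clone)]
pub enum Opcode {
    Mul,
    Div,
    Add,
    Sub,
    Eq,
    Neq,
}

/// Evaluates `expr` with the same recursive shape as `compile_expr`:
/// leaves produce a value directly, `Op` first evaluates both sides.
pub fn eval(locals: &HashMap<String, i64>, expr: &Expr) -> i64 {
    match expr {
        Expr::Number(value) => *value,
        // Panics on an unknown variable, like a missing `locals` entry.
        Expr::Variable(name) => locals[name],
        Expr::Op(lhs, opcode, rhs) => {
            let l = eval(locals, lhs);
            let r = eval(locals, rhs);
            match opcode {
                Opcode::Mul => l.wrapping_mul(r),
                // Truncating signed division, like arith.divsi.
                Opcode::Div => l / r,
                Opcode::Add => l.wrapping_add(r),
                Opcode::Sub => l.wrapping_sub(r),
                Opcode::Eq => (l == r) as i64,
                Opcode::Neq => (l != r) as i64,
            }
        }
    }
}
```

For example, with x bound to 6, the expression x * (2 + 5) evaluates to 42.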

Workshop: Compiling Let and Assign

Let statement

A let statement declares a variable. As explained before, we need to allocate space for it, store the value of the compiled expression there, and save the pointer in the locals hashmap.

You will need to use the llvm dialect.

#![allow(unused)]
fn main() {
// src/codegen/let_stmt.rs
// let x = 2;
pub fn compile_let<'ctx: 'parent, 'parent>(
    ctx: &ModuleCtx<'ctx>,
    locals: &mut HashMap<String, Value<'ctx, 'parent>>,
    block: &'parent Block<'ctx>,
    stmt: &LetStmt,
) {
    todo!()
}
}

Assign statement

Assign is like let, but it doesn't create the variable: it only stores the updated value through the pointer already saved in locals.

#![allow(unused)]
fn main() {
// src/codegen/let_stmt.rs
// x = 2;
pub fn compile_assign<'ctx: 'parent, 'parent>(
    ctx: &ModuleCtx<'ctx>,
    locals: &mut HashMap<String, Value<'ctx, 'parent>>,
    block: &'parent Block<'ctx>,
    stmt: &AssignStmt,
) {
    todo!("implement assign")
}
}
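The difference between the two statements can be modeled without MLIR. In this hypothetical Frame type, a Vec<i64> stands in for stack memory and indices stand in for pointers: let allocates and then stores, while assign only stores through the pointer already in locals. It only models the runtime effect of the operations; the real functions append llvm.alloca and llvm.store operations to block instead.

```rust
use std::collections::HashMap;

/// Toy stack frame: `alloca` hands out "pointers" (indices into `slots`),
/// and `load`/`store` read and write through them.
#[derive(Default)]
pub struct Frame {
    slots: Vec<i64>,
    /// Variable name -> pointer, like the `locals` hashmap in codegen.
    pub locals: HashMap<String, usize>,
}

impl Frame {
    fn alloca(&mut self) -> usize {
        self.slots.push(0);
        self.slots.len() - 1
    }

    pub fn load(&self, ptr: usize) -> i64 {
        self.slots[ptr]
    }

    fn store(&mut self, ptr: usize, value: i64) {
        self.slots[ptr] = value;
    }

    /// `let name = value;`: new slot, store into it, remember the pointer.
    pub fn compile_let(&mut self, name: &str, value: i64) {
        let ptr = self.alloca();
        self.store(ptr, value);
        self.locals.insert(name.to_string(), ptr);
    }

    /// `name = value;`: reuse the existing pointer and only store.
    pub fn compile_assign(&mut self, name: &str, value: i64) {
        let ptr = self.locals[name];
        self.store(ptr, value);
    }

    /// Reading a variable is a load through its pointer.
    pub fn get(&self, name: &str) -> i64 {
        self.load(self.locals[name])
    }
}
```

After let x = 2; x = 5;, reading x loads 5 through the same pointer, and only one slot was allocated for it.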

Workshop: Compiling Return

The return statement evaluates the expression and returns the computed value.

You will need to check the func dialect, although it is also possible to do it with the llvm dialect.

#![allow(unused)]
fn main() {
pub fn compile_return<'ctx, 'parent>(
    ctx: &ModuleCtx<'ctx>,
    locals: &HashMap<String, Value>,
    block: &'parent Block<'ctx>,
    stmt: &ReturnStmt,
) {
    todo!()
}
}

Workshop: Compiling If/Else

To get simple control flow working, you will use the SCF dialect. With this dialect you don't need to add extra blocks, since the control flow will be contained within the regions inside the SCF operations.

The only limitation is that we can't do early returns this way, but for this simple language it won't matter.

You will need to clone the locals HashMap inside the created regions to avoid lifetime issues. Since any variable created inside the if or else block only lives for that scope, this works well, and assignments to outer variables still go through the same pointers.

#![allow(unused)]
fn main() {
// src/codegen/ifelse_stmt.rs
pub fn compile_if<'ctx, 'parent>(
    ctx: &ModuleCtx<'ctx>,
    locals: &mut HashMap<String, Value<'ctx, 'parent>>,
    block: &'parent Block<'ctx>,
    stmt: &IfStmt,
) {
    todo!()
}
}
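Why cloning locals is safe can be checked with a small simulation in plain Rust. This is a hypothetical model, not the workshop's codegen: the map stores pointers (indices into a Vec<i64>), so a branch that assigns an outer variable writes through the same pointer, while a variable declared inside the branch only exists in the cloned map.

```rust
use std::collections::HashMap;

/// Toy memory: "pointers" are indices into `mem`.
pub struct Memory {
    pub mem: Vec<i64>,
}

impl Memory {
    fn alloca(&mut self, init: i64) -> usize {
        self.mem.push(init);
        self.mem.len() - 1
    }
}

/// Simulates:
///     let x = 1;
///     if cond { let tmp = 10; x = tmp; } else { x = 2; }
/// Each branch works on a clone of `locals`. Returns the value of x
/// afterwards and whether `tmp` is visible in the outer scope.
pub fn run_if(cond: bool) -> (i64, bool) {
    let mut memory = Memory { mem: Vec::new() };
    let mut locals: HashMap<String, usize> = HashMap::new();
    let x = memory.alloca(1);
    locals.insert("x".to_string(), x);

    if cond {
        // Then region: `tmp` is only added to the cloned map.
        let mut then_locals = locals.clone();
        let tmp = memory.alloca(10);
        then_locals.insert("tmp".to_string(), tmp);
        // `x = tmp;` writes through the same pointer the outer scope holds.
        let value = memory.mem[then_locals["tmp"]];
        memory.mem[then_locals["x"]] = value;
    } else {
        // Else region: same idea, `x = 2;`.
        let else_locals = locals.clone();
        memory.mem[else_locals["x"]] = 2;
    }

    (memory.mem[locals["x"]], locals.contains_key("tmp"))
}
```

run_if(true) gives (10, false): the assignment is visible after the if, while tmp did not leak into the outer locals.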

Workshop: Compiling Function calls

#![allow(unused)]
fn main() {
// src/codegen/expressions.rs
Expr::Call { target, args } => todo!("implement function call"),
}

Since all arguments have the same type, and for simplicity's sake we don't verify that the number of arguments matches the function signature, this should be relatively simple using the func dialect.

Workshop: Compiling Functions

Now, to wrap up, the function itself needs to be created using the func dialect and appended to the module's body() block (the module is available through the ctx variable).

You also need to allocate space for each argument and store its value there. You can get the values from the block arguments.

Remember that in this language functions always return an i64 value.

Some useful types you will need: Type, IntegerAttribute, IntegerType, FunctionType, TypeAttribute, StringAttribute.

#![allow(unused)]
fn main() {
// src/codegen.rs:60+
fn compile_function(ctx: &ModuleCtx<'_>, func: &Function) {
    let mut args: Vec<(Type, Location)> = vec![];
    let mut func_args: Vec<Type> = Vec::new();

    for _ in &func.args {
        args.push((
            IntegerType::new(ctx.ctx, 64).into(),
            Location::unknown(ctx.ctx),
        ));
        func_args.push(IntegerType::new(ctx.ctx, 64).into());
    }

    let region = Region::new();
    let block = region.append_block(Block::new(&args));
    let mut locals: HashMap<String, Value> = HashMap::new();

    // Allocate space for the arguments, get them from the block, storing them and save them on locals hashmap.

    for stmt in &func.body.stmts {
        compile_statement(ctx, &mut locals, &block, stmt);
    }

    // Create the func operation here.
}
}

Workshop: Testing

Now that the functionality is implemented, you can run the tests included in the repo with:

cargo test

The tests are programs under the test/ directory. Each one is a function with a well-defined name and signature, so we can easily call it from Rust using the C calling convention. The tests are run using the LLVM JIT engine.
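Calling compiled code "using the C call convention" boils down to agreeing on an extern "C" function-pointer type. The sketch below fakes the JIT part: it takes the address of an ordinary Rust extern "C" function, casts it back to a typed pointer and calls it. The names add, AddFn and call_by_address are hypothetical, and how the workshop's tests actually obtain addresses from the LLVM JIT engine is not shown here.

```rust
/// Stand-in for a compiled test function; in the real setup the body
/// would come from the JIT-compiled program.
pub extern "C" fn add(a: i64, b: i64) -> i64 {
    a + b
}

/// The signature the Rust caller and the compiled code must agree on.
pub type AddFn = extern "C" fn(i64, i64) -> i64;

/// Calls a function known only by its raw address.
///
/// # Safety
/// `addr` must be the address of a function with exactly the `AddFn`
/// signature and the C calling convention.
pub unsafe fn call_by_address(addr: usize, a: i64, b: i64) -> i64 {
    // SAFETY: guaranteed by the caller, see the doc comment above.
    let f: AddFn = unsafe { std::mem::transmute::<usize, AddFn>(addr) };
    f(a, b)
}
```

If the declared signature does not match the real function, the call is undefined behavior, which is why the tests rely on well-defined names and signatures.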

Workshop: Glue code

This section explains the glue code, mostly how the lowering and compilation work.

Initial steps

First, the registry of dialects and the MLIR context need to be initialized; then we append the registry to the context and load the dialects.

#![allow(unused)]
fn main() {
// src/codegen.rs
pub fn compile_program(program: &Program, optlevel: OptLevel, out_name: &Path) {
    // We need a registry to hold all the dialects
    let registry = DialectRegistry::new();
    // Register all dialects that come with MLIR.
    register_all_dialects(&registry);
    let context = Context::new();
    context.append_dialect_registry(&registry);
    context.load_all_available_dialects();
    // ...
}
}

Next, initialize the Module and, for convenience, put both the context and the module in a struct. Afterwards, we compile all the functions.

#![allow(unused)]
fn main() {
let mut module = Module::new(Location::unknown(&context));
let ctx = ModuleCtx {
    ctx: &context,
    module: &module,
};

for func in &program.functions {
    compile_function(&ctx, func);
}
}

Now the module contains operations from the various dialects we used. To compile it with LLVM, we need to convert them all to the LLVM dialect, which requires a PassManager with the passes that transform and convert the dialects.

#![allow(unused)]
fn main() {
// Run passes on module to convert all dialects to LLVM.
let pass_manager = PassManager::new(&context);
pass_manager.enable_verifier(true);
pass_manager.add_pass(pass::transform::create_canonicalizer());
pass_manager.add_pass(pass::conversion::create_scf_to_control_flow()); // needed because to_llvm doesn't include it.
pass_manager.add_pass(pass::conversion::create_to_llvm());
pass_manager.run(&mut module).unwrap();
}

End

If you reached this by following all the steps, you should now have a (really) minimal working language. Well done!

Some projects to check out

  • Cairo Native makes extensive use of MLIR.
  • Concrete's first codegen backend is made with MLIR; it should be easier to read than the Cairo Native code.