Enforcing Guardrails on Choice Selection
In this example, we want the LLM to pick an action (e.g. fight or flight), and based on that action we want to return different JSON objects. For example, if the action is fight, we want to return a JSON object that contains the weapon field. If the action is flight, we want to return a JSON object that contains the direction and distance fields.
We make the assumption that:
- We don't need any external libraries that are not already installed in the environment.
- We are able to execute the code in the environment.
Objective
We want the LLM to play a role-playing (RP) game in which it chooses between two actions, fight or flight. If it chooses fight, the LLM should choose a weapon and an enemy. If it chooses flight, the LLM should choose a direction and a distance.
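As a concrete picture of the two shapes, the sketch below writes one example object per action as plain Python dictionaries. The field names follow the spec in Step 1 (chosen_action, weapon, flight_direction, distance); the nesting under an action key and the specific values are illustrative assumptions, not captured model output.

```python
import json

# Illustrative objects only: field names follow the spec in Step 1,
# the concrete values are made up for this example.
fight_example = {"action": {"chosen_action": "fight", "weapon": "crossbow"}}
flight_example = {
    "action": {
        "chosen_action": "flight",
        "flight_direction": "north",
        "distance": 2,
    }
}

for example in (fight_example, flight_example):
    print(json.dumps(example, indent=2))
```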
Step 1: Generating RAIL Spec
Ordinarily, we could create a separate RAIL spec in a file. However, for the sake of this example, we will generate the RAIL spec in the notebook as a string or a Pydantic Model.
XML option:
rail_str = """
<rail version="0.1">
<output>
    <choice discriminator="chosen_action" name="action" on-fail-choice="reask">
        <case name="fight">
            <string format="valid-choices: {['crossbow', 'machine gun']}" name="weapon" on-fail-valid-choices="reask"></string>
        </case>
        <case name="flight">
            <string format="valid-choices: {['north','south','east','west']}" name="flight_direction" on-fail-valid-choices="exception"></string>
            <integer format="valid-choices: {[1,2,3,4]}" name="distance" on-fail-valid-choices="exception"></integer>
        </case>
    </choice>
</output>
<prompt>
You are a human in an enchanted forest. You come across opponents of different types, and you should fight smaller opponents and run away from bigger ones.
You run into a ${opp_type}. What do you do?
${gr.complete_json_suffix_v2}</prompt>
</rail>
"""
Pydantic model option:
from guardrails.validators import ValidChoices
from pydantic import BaseModel, Field
from typing import Literal, Union

prompt = """
You are a human in an enchanted forest. You come across opponents of different types, and you should fight smaller opponents and run away from bigger ones.
You run into a ${opp_type}. What do you do?
${gr.complete_json_suffix_v2}"""


class Fight(BaseModel):
    chosen_action: Literal['fight']
    weapon: str = Field(validators=[ValidChoices(['crossbow', 'machine gun'], on_fail="reask")])


class Flight(BaseModel):
    chosen_action: Literal['flight']
    flight_direction: str = Field(validators=[ValidChoices(['north','south','east','west'], on_fail="exception")])
    distance: int = Field(validators=[ValidChoices([1,2,3,4], on_fail="exception")])


class FightOrFlight(BaseModel):
    action: Union[Fight, Flight] = Field(discriminator='chosen_action')
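To make the rules encoded by either option explicit, here is a minimal plain-Python sketch of the same checks. It uses neither Guardrails nor Pydantic; the function name validate_action and the ALLOWED table are hypothetical and exist only for illustration.

```python
# Allowed values per case, copied from the spec above.
ALLOWED = {
    "fight": {"weapon": ["crossbow", "machine gun"]},
    "flight": {
        "flight_direction": ["north", "south", "east", "west"],
        "distance": [1, 2, 3, 4],
    },
}


def validate_action(action: dict) -> list:
    """Return a list of problems; an empty list means the action is valid."""
    # The discriminator decides which set of fields must be present.
    case = action.get("chosen_action")
    if case not in ALLOWED:
        return [f"chosen_action must be one of {sorted(ALLOWED)}, got {case!r}"]

    problems = []
    for field, choices in ALLOWED[case].items():
        if field not in action:
            problems.append(f"missing field {field!r}")
        elif action[field] not in choices:
            problems.append(f"{field}={action[field]!r} not in {choices}")
    return problems


print(validate_action({"chosen_action": "fight", "weapon": "crossbow"}))
print(validate_action({"chosen_action": "flight", "flight_direction": "up", "distance": 2}))
```

The discriminator is checked first because it determines which fields are even meaningful; only then are the per-field choices checked.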
Step 2: Create a Guard object with the RAIL Spec
We create a gd.Guard object that will check, validate and correct the LLM output. This object:
- Enforces the quality criteria specified in the RAIL spec (i.e. a valid action whose fields take one of the allowed values).
- Takes corrective action when the quality criteria are not met (i.e. reasking the LLM or raising an exception, as set by the on-fail attributes).
- Compiles the schema and type info from the RAIL spec and adds it to the prompt.
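The check, reask and return cycle described in this list can be pictured with a small self-contained loop. This is a conceptual sketch, not Guardrails' implementation: fake_llm, is_valid and guarded_call are hypothetical stand-ins, and the fake model simply returns a disallowed weapon first and an allowed one once the prompt carries a correction.

```python
import json

VALID_WEAPONS = ["crossbow", "machine gun"]


def fake_llm(prompt: str) -> str:
    """Stand-in for a model: answers badly unless the prompt carries a correction."""
    weapon = "crossbow" if "Allowed weapons" in prompt else "sword"
    return json.dumps({"action": {"chosen_action": "fight", "weapon": weapon}})


def is_valid(output: dict) -> bool:
    """Check only the fight case, to keep the sketch short."""
    action = output.get("action", {})
    return action.get("chosen_action") == "fight" and action.get("weapon") in VALID_WEAPONS


def guarded_call(prompt: str, num_reasks: int = 1):
    """Call the model, validate, and reask with a hint until valid or out of reasks."""
    attempts = 0
    current_prompt = prompt
    while True:
        raw = fake_llm(current_prompt)
        attempts += 1
        parsed = json.loads(raw)
        if is_valid(parsed):
            return raw, parsed, attempts
        if attempts > num_reasks:
            # Out of reasks: hand back the raw text without a validated output.
            return raw, None, attempts
        current_prompt = f"{prompt}\nAllowed weapons: {VALID_WEAPONS}. Please answer again."


raw, validated, attempts = guarded_call("You run into a goblin. What do you do?")
print(attempts, validated)
```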
From XML:
Or from Pydantic:
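A hedged sketch of the two construction paths follows, assuming a Guardrails release contemporary with the spec above, in which Guard exposes from_rail_string and from_pydantic. The wrapper function names are hypothetical, the import is deferred so the functions can be defined without the library installed, and newer releases may name these constructors differently.

```python
def guard_from_xml(rail_str: str):
    """Build a guard from the RAIL string (assumes Guard.from_rail_string exists)."""
    import guardrails as gd  # deferred import: only needed when the function is called

    return gd.Guard.from_rail_string(rail_str)


def guard_from_pydantic(output_class, prompt: str):
    """Build a guard from the Pydantic model (assumes Guard.from_pydantic exists)."""
    import guardrails as gd  # deferred import, as above

    return gd.Guard.from_pydantic(output_class=output_class, prompt=prompt)
```

With the definitions from Step 1, these would be called as guard_from_xml(rail_str) or guard_from_pydantic(FightOrFlight, prompt).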
The Guard object compiles the output schema and adds it to the prompt. We can see the final prompt below:
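The ${opp_type} placeholder in this prompt is a template variable that is expected to be filled in with the opponent when the guard is called. As a rough standard-library illustration of that substitution only (string.Template happens to share the ${...} syntax; this is not a claim about how Guardrails renders prompts, and the schema suffix is left out):

```python
from string import Template

# Only the question line of the prompt; the schema suffix is omitted on purpose.
question = Template("You run into a ${opp_type}. What do you do?")

for opponent in ("giant", "goblin"):
    print(question.substitute(opp_type=opponent))
```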
Step 3: Wrap the LLM API call with Guard
We can now wrap the LLM API call with the Guard object. This will ensure that the LLM generates an output that is compliant with the RAIL spec.
To start, we test with a 'giant' as an opponent, and look at the output.
Running the cell above returns:
1. The raw LLM text output as a single string.
2. A dictionary containing the validated output, where the action key holds the chosen action and the fields that belong to it.
We can see that if the LLM chooses flight, the output is a dictionary with flight_direction and distance fields.
We can inspect the logs of the guard object to see the quality criteria that were checked and the corrective actions that were taken.
Now, let's test with a goblin as an opponent.
We can see that the LLM chose to fight and the output is a choice of weapon.
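Downstream code usually needs to branch on which action came back. A minimal sketch of that dispatch is shown here, assuming the validated output nests the fields under an action key as in the spec; describe is a hypothetical helper and the inputs are hand-written samples, not real model output.

```python
def describe(output: dict) -> str:
    """Turn a validated output into a short sentence, branching on chosen_action."""
    action = output["action"]
    if action["chosen_action"] == "fight":
        return f"Fight with the {action['weapon']}"
    if action["chosen_action"] == "flight":
        return f"Flee {action['flight_direction']}, distance {action['distance']}"
    raise ValueError(f"unexpected chosen_action: {action['chosen_action']!r}")


# Hand-written samples, one per case.
print(describe({"action": {"chosen_action": "fight", "weapon": "crossbow"}}))
print(describe({"action": {"chosen_action": "flight", "flight_direction": "west", "distance": 3}}))
```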
We can inspect the state of the guard after each call to see what happened.