Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions examples/coupled_trial_generator.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,58 @@
import logging
import random

import numpy as np

from aind_behavior_dynamic_foraging.task_logic.trial_generators.coupled_trial_generator import CoupledTrialGeneratorSpec
from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome


def simulate_response(
previous_reward: bool, previous_choice: bool | None, previous_left_bait: bool, previous_right_bait: bool
) -> TrialOutcome:

np.random.seed(42)

if np.random.random(1) < 0.1:
is_right_choice = None
elif np.random.random(1) < 0:
is_right_choice = False
elif previous_choice is None:
is_right_choice = random.choice([True, False])
else:
is_right_choice = previous_choice if previous_reward else not previous_choice

if is_right_choice is None:
is_rewarded = False
else:
is_rewarded = previous_right_bait if is_right_choice else previous_left_bait

return TrialOutcome(trial=Trial(), is_right_choice=is_right_choice, is_rewarded=is_rewarded)


def main():
coupled_trial_generator = CoupledTrialGeneratorSpec().create_generator()
trial = Trial()
outcome = TrialOutcome(
trial=trial, is_right_choice=random.choice([True, False, None]), is_rewarded=random.choice([True, False])
)
for i in range(100):
trial_outcome = TrialOutcome(
trial=trial, is_right_choice=random.choice([True, False, None]), is_rewarded=random.choice([True, False])
)
coupled_trial_generator.update(trial_outcome)
trial = coupled_trial_generator.next()
coupled_trial_generator.update(outcome)
outcome = simulate_response(
previous_reward=outcome.is_rewarded,
previous_choice=outcome.is_right_choice,
previous_left_bait=False,
previous_right_bait=False,
)

if not trial:
print("Session finished")
return

print(f"Next trial: {trial}")


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.DEBUG)
main()
57 changes: 57 additions & 0 deletions examples/warmup_trial_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
import random

import numpy as np

from aind_behavior_dynamic_foraging.task_logic.trial_generators import WarmupTrialGeneratorSpec
from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome


def simulate_response(
previous_reward: bool, previous_choice: bool | None, previous_left_bait: bool, previous_right_bait: bool
) -> TrialOutcome:

np.random.seed(42)

if np.random.random(1) < 0.1:
is_right_choice = None
elif np.random.random(1) < 0:
is_right_choice = False
elif previous_choice is None:
is_right_choice = random.choice([True, False])
else:
is_right_choice = previous_choice if previous_reward else not previous_choice

if is_right_choice is None:
is_rewarded = False
else:
is_rewarded = previous_right_bait if is_right_choice else previous_left_bait

return TrialOutcome(trial=Trial(), is_right_choice=is_right_choice, is_rewarded=is_rewarded)


def main():
warmup_trial_generator = WarmupTrialGeneratorSpec().create_generator()
trial = Trial()
outcome = TrialOutcome(
trial=trial, is_right_choice=random.choice([True, False, None]), is_rewarded=random.choice([True, False])
)
for i in range(100):
trial = warmup_trial_generator.next()
warmup_trial_generator.update(outcome)
outcome = simulate_response(
previous_reward=outcome.is_rewarded,
previous_choice=outcome.is_right_choice,
previous_left_bait=warmup_trial_generator.is_left_baited,
previous_right_bait=warmup_trial_generator.is_right_baited,
)
if not trial:
print("Warmup finished")
return

print(f"Next trial: {trial}")


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
main()
161 changes: 160 additions & 1 deletion schema/aind_behavior_dynamic_foraging.json
Original file line number Diff line number Diff line change
Expand Up @@ -2824,11 +2824,15 @@
"mapping": {
"CoupledTrialGenerator": "#/$defs/CoupledTrialGeneratorSpec",
"IntegrationTestTrialGenerator": "#/$defs/IntegrationTestTrialGeneratorSpec",
"TrialGeneratorComposite": "#/$defs/TrialGeneratorCompositeSpec_TrialGeneratorSpec_"
"TrialGeneratorComposite": "#/$defs/TrialGeneratorCompositeSpec_TrialGeneratorSpec_",
"WarmupTrialGenerator": "#/$defs/WarmupTrialGeneratorSpec"
},
"propertyName": "type"
},
"oneOf": [
{
"$ref": "#/$defs/WarmupTrialGeneratorSpec"
},
{
"$ref": "#/$defs/CoupledTrialGeneratorSpec"
},
Expand Down Expand Up @@ -3060,6 +3064,161 @@
"title": "VideoWriterOpenCv",
"type": "object"
},
"WarmupTrialGenerationEndConditions": {
"properties": {
"min_trial": {
"default": 50,
"description": "Minimum trials in generator.",
"minimum": 0,
"title": "Min Trial",
"type": "integer"
},
"max_choice_bias": {
"default": 0.1,
"description": "Maximum allowed deviation from 50/50 choice ratio to end trial generation.",
"maximum": 1,
"minimum": 0,
"title": "Max Choice Bias",
"type": "number"
},
"min_response_rate": {
"default": 0.8,
"description": "Minimum fraction of trials with a choice (non-ignored) to end trial generation.",
"maximum": 1,
"minimum": 0,
"title": "Min Response Rate",
"type": "number"
},
"evaluation_window": {
"default": 20,
"description": "Number of most recent trials to evaluate the end criteria.",
"minimum": 0,
"title": "Evaluation Window",
"type": "integer"
}
},
"title": "WarmupTrialGenerationEndConditions",
"type": "object"
},
"WarmupTrialGeneratorSpec": {
"properties": {
"type": {
"const": "WarmupTrialGenerator",
"default": "WarmupTrialGenerator",
"title": "Type",
"type": "string"
},
"quiescent_duration": {
"$ref": "#/$defs/Distribution",
"default": {
"family": "Exponential",
"distribution_parameters": {
"family": "Exponential",
"rate": 1.0
},
"truncation_parameters": {
"max": 1.0,
"min": 0.0,
"truncation_mode": "exclude"
},
"scaling_parameters": null
},
"description": "Distribution describing the quiescence period before trial starts (in seconds). Each lick resets the timer."
},
"response_duration": {
"default": 1.0,
"description": "Duration after go cue for animal response.",
"minimum": 0,
"title": "Response Duration",
"type": "number"
},
"reward_consumption_duration": {
"default": 3.0,
"description": "Duration of reward consumption before transition to ITI (in seconds).",
"minimum": 0,
"title": "Reward Consumption Duration",
"type": "number"
},
"inter_trial_interval_duration": {
"$ref": "#/$defs/Distribution",
"default": {
"family": "Exponential",
"distribution_parameters": {
"family": "Exponential",
"rate": 0.5
},
"truncation_parameters": {
"max": 8.0,
"min": 1.0,
"truncation_mode": "exclude"
},
"scaling_parameters": null
},
"description": "Distribution describing the inter-trial interval (in seconds)."
},
"block_len": {
"$ref": "#/$defs/ExponentialDistribution",
"default": {
"family": "Exponential",
"distribution_parameters": {
"family": "Exponential",
"rate": 1.0
},
"truncation_parameters": {
"max": 2.0,
"min": 1.0,
"truncation_mode": "exclude"
},
"scaling_parameters": null
},
"description": "Distribution describing block length."
},
"min_block_reward": {
"const": 1,
"default": 1,
"title": "Minimal rewards in a block to switch",
"type": "integer"
},
"kernel_size": {
"default": 2,
"description": "Kernel to evaluate choice fraction.",
"title": "Kernel Size",
"type": "integer"
},
"reward_probability_parameters": {
"$ref": "#/$defs/RewardProbabilityParameters",
"default": {
"base_reward_sum": 0.8,
"reward_pairs": [
[
8.0,
1.0
]
]
},
"description": "Parameters defining the reward probability structure."
},
"is_baiting": {
"const": true,
"default": true,
"description": "Whether uncollected rewards carry over to the next trial.",
"title": "Is Baiting",
"type": "integer"
},
"trial_generation_end_parameters": {
"$ref": "#/$defs/WarmupTrialGenerationEndConditions",
"default": {
"min_trial": 50,
"max_choice_bias": 0.1,
"min_response_rate": 0.8,
"evaluation_window": 20
},
"description": "Conditions to end trial generation."
}
},
"title": "WarmupTrialGeneratorSpec",
"type": "object"
},
"WaterValveCalibration": {
"description": "Represents a water valve calibration.",
"properties": {
Expand Down
Loading