Skip to content

Commit 9615fc9

Browse files
committed
feat: add aws-rft-sdk package for Reinforcement Fine-Tuning
Add a standalone SDK package that integrates SageMaker RFT (Reinforcement Fine-Tuning) with Strands agent framework. Provides: - RolloutFeedbackClient: report rewards back to the training service - @rft_handler: decorator to extract rollout metadata from payloads - RFTContext: thread-local context for propagating training metadata - wrap_model: Strands model adapter that injects X-RFT-* headers
1 parent 71c8d70 commit 9615fc9

File tree

7 files changed

+308
-0
lines changed

7 files changed

+308
-0
lines changed

aws-rft-sdk/pyproject.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
[build-system]
2+
requires = ["hatchling"]
3+
build-backend = "hatchling.build"
4+
5+
[project]
6+
name = "aws-rft-sdk"
7+
version = "0.1.0"
8+
description = "AWS Reinforcement Fine-Tuning SDK for online rollout-based training"
9+
readme = {text = "", content-type = "text/markdown"}
10+
requires-python = ">=3.9"
11+
dependencies = [
12+
"boto3>=1.35.0",
13+
]
14+
15+
[project.optional-dependencies]
16+
strands = ["strands-agents>=0.1.0"]
17+
18+
[tool.hatch.build.targets.wheel]
19+
packages = ["src/aws_rft_sdk"]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from aws_rft_sdk.client import RolloutFeedbackClient
2+
from aws_rft_sdk.handler import rft_handler
3+
from aws_rft_sdk.context import RFTContext
4+
5+
__all__ = ["RolloutFeedbackClient", "rft_handler", "RFTContext"]

aws-rft-sdk/src/aws_rft_sdk/adapters/__init__.py

Whitespace-only changes.
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Strands model adapter — wraps a Strands model to inject RFT headers.
2+
3+
Usage::
4+
5+
from aws_rft_sdk.adapters.strands import wrap_model
6+
from strands.models.openai import OpenAIModel
7+
8+
model = OpenAIModel(
9+
client_args={"api_key": key, "base_url": endpoint},
10+
model_id="my-model",
11+
)
12+
model = wrap_model(model) # Now injects X-RFT-* headers on every call
13+
14+
Requires the Strands OpenAIModel to pass through ``extra_headers`` kwarg
15+
to the underlying OpenAI client (supported since strands-agents >= X.Y.Z).
16+
"""
17+
18+
import logging
19+
from typing import Any
20+
21+
from aws_rft_sdk.context import RFTContext
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def wrap_model(model: Any) -> Any:
27+
"""Wrap a Strands model to automatically inject RFT training headers.
28+
29+
The wrapper reads the current rollout context (populated by ``@rft_handler``)
30+
and adds ``X-RFT-*`` headers to every inference request so the training
31+
inference endpoint can correlate requests with rollouts.
32+
33+
Args:
34+
model: A Strands model instance (e.g., ``OpenAIModel``).
35+
36+
Returns:
37+
A wrapped model that transparently injects RFT headers.
38+
"""
39+
return _RFTModelWrapper(model)
40+
41+
42+
class _RFTModelWrapper:
43+
"""Transparent proxy that injects RFT headers into Strands model calls.
44+
45+
Delegates all attribute access to the inner model so it quacks like
46+
the original. Intercepts ``stream()`` to inject ``extra_headers``.
47+
"""
48+
49+
def __init__(self, inner_model: Any):
50+
object.__setattr__(self, "_inner", inner_model)
51+
52+
def __getattr__(self, name: str) -> Any:
53+
return getattr(self._inner, name)
54+
55+
def __setattr__(self, name: str, value: Any):
56+
if name == "_inner":
57+
object.__setattr__(self, name, value)
58+
else:
59+
setattr(self._inner, name, value)
60+
61+
def stream(self, *args: Any, **kwargs: Any) -> Any:
62+
"""Intercept stream() to inject RFT headers via extra_headers kwarg."""
63+
rft_headers = RFTContext.get_headers()
64+
if rft_headers:
65+
existing = kwargs.get("extra_headers") or {}
66+
existing.update(rft_headers)
67+
kwargs["extra_headers"] = existing
68+
logger.debug("Injected RFT headers: %s", list(rft_headers.keys()))
69+
return self._inner.stream(*args, **kwargs)
70+
71+
def update_config(self, **model_config: Any) -> None:
72+
return self._inner.update_config(**model_config)
73+
74+
def get_config(self) -> Any:
75+
return self._inner.get_config()
76+
77+
def structured_output(self, *args: Any, **kwargs: Any) -> Any:
78+
return self._inner.structured_output(*args, **kwargs)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""RolloutFeedbackClient — reports rewards and completion status back to the training service."""
2+
3+
import logging
4+
from typing import Optional
5+
6+
import boto3
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class RolloutFeedbackClient:
12+
"""Client for reporting rollout feedback (rewards) to the RFT training service.
13+
14+
Typically used inside an @rft_handler-decorated entrypoint to report
15+
the reward computed from the agent's rollout.
16+
17+
Example::
18+
19+
from aws_rft_sdk import RolloutFeedbackClient
20+
21+
client = RolloutFeedbackClient(payload.get("metadata"))
22+
client.report_complete(reward=0.85)
23+
24+
Args:
25+
metadata: The ``metadata`` dict from the rollout payload. Contains
26+
training_job_arn, rollout_id, feedback_endpoint, etc.
27+
"""
28+
29+
def __init__(self, metadata: dict):
30+
self._metadata = metadata or {}
31+
self._training_job_arn = self._metadata.get("training_job_arn")
32+
self._rollout_id = self._metadata.get("rollout_id")
33+
self._feedback_endpoint = self._metadata.get("feedback_endpoint")
34+
self._client = None
35+
36+
def _get_client(self):
37+
if self._client is None:
38+
kwargs = {}
39+
if self._feedback_endpoint:
40+
kwargs["endpoint_url"] = self._feedback_endpoint
41+
self._client = boto3.client("sagemaker", **kwargs)
42+
return self._client
43+
44+
def report_complete(self, reward: float):
45+
"""Report successful rollout completion with a reward score.
46+
47+
Args:
48+
reward: The computed reward for this rollout (typically 0.0–1.0).
49+
"""
50+
logger.info(
51+
"Reporting rollout complete: training_job=%s rollout=%s reward=%s",
52+
self._training_job_arn,
53+
self._rollout_id,
54+
reward,
55+
)
56+
# TODO: Replace with actual RFT feedback API call when available.
57+
# The service API will accept:
58+
# - TrainingJobArn
59+
# - RolloutId
60+
# - Reward (float)
61+
# - Status (COMPLETED)
62+
client = self._get_client()
63+
# Placeholder — actual API TBD
64+
# client.send_rollout_feedback(
65+
# TrainingJobArn=self._training_job_arn,
66+
# RolloutId=self._rollout_id,
67+
# Reward=reward,
68+
# Status="COMPLETED",
69+
# )
70+
71+
def report_error(self, error: str, reward: Optional[float] = None):
72+
"""Report a rollout error.
73+
74+
Args:
75+
error: Error description.
76+
reward: Optional partial reward (defaults to 0.0).
77+
"""
78+
logger.error(
79+
"Reporting rollout error: training_job=%s rollout=%s error=%s",
80+
self._training_job_arn,
81+
self._rollout_id,
82+
error,
83+
)
84+
# TODO: Replace with actual RFT feedback API call when available.
85+
# client.send_rollout_feedback(
86+
# TrainingJobArn=self._training_job_arn,
87+
# RolloutId=self._rollout_id,
88+
# Reward=reward or 0.0,
89+
# Status="FAILED",
90+
# ErrorMessage=error,
91+
# )
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Thread-local context for RFT rollout metadata.
2+
3+
The rft_handler decorator populates this context from the payload metadata.
4+
The Strands model wrapper reads it to inject per-request headers.
5+
"""
6+
7+
import threading
8+
from typing import Optional
9+
10+
_context = threading.local()
11+
12+
13+
class RFTContext:
14+
"""Access the current RFT rollout context.
15+
16+
Set by @rft_handler, read by wrap_model adapters to inject headers.
17+
"""
18+
19+
@staticmethod
20+
def get_headers() -> dict:
21+
"""Return HTTP headers for the current rollout context."""
22+
metadata = getattr(_context, "metadata", None)
23+
if metadata is None:
24+
return {}
25+
headers = {}
26+
if metadata.get("training_job_arn"):
27+
headers["X-RFT-Training-Job-Arn"] = metadata["training_job_arn"]
28+
if metadata.get("rollout_id"):
29+
headers["X-RFT-Rollout-Id"] = metadata["rollout_id"]
30+
if metadata.get("episode_id"):
31+
headers["X-RFT-Episode-Id"] = metadata["episode_id"]
32+
return headers
33+
34+
@staticmethod
35+
def get_metadata() -> Optional[dict]:
36+
"""Return the raw metadata dict, or None if not in an RFT context."""
37+
return getattr(_context, "metadata", None)
38+
39+
40+
def _set_metadata(metadata: dict):
41+
_context.metadata = metadata
42+
43+
44+
def _clear_metadata():
45+
_context.metadata = None
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""@rft_handler decorator — wraps an entrypoint to manage RFT rollout context."""
2+
3+
import asyncio
4+
import functools
5+
import inspect
6+
import logging
7+
8+
from aws_rft_sdk.client import RolloutFeedbackClient
9+
from aws_rft_sdk.context import _set_metadata, _clear_metadata
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
def rft_handler(func):
15+
"""Decorator that sets up RFT rollout context around an entrypoint.
16+
17+
Extracts ``metadata`` from the payload, makes it available via
18+
``RFTContext.get_headers()`` (used by ``wrap_model``), and auto-reports
19+
errors if the function raises.
20+
21+
Works with both sync and async functions.
22+
23+
Example::
24+
25+
@app.entrypoint
26+
@rft_handler
27+
async def invoke_agent(payload):
28+
user_input = payload.get("instance")
29+
response = await agent.invoke_async(user_input)
30+
return response.message["content"][0]["text"]
31+
"""
32+
33+
if asyncio.iscoroutinefunction(func):
34+
35+
@functools.wraps(func)
36+
async def async_wrapper(payload, *args, **kwargs):
37+
metadata = payload.get("metadata", {}) if isinstance(payload, dict) else {}
38+
_set_metadata(metadata)
39+
try:
40+
return await func(payload, *args, **kwargs)
41+
except Exception as e:
42+
logger.error("RFT rollout failed: %s", e)
43+
try:
44+
RolloutFeedbackClient(metadata).report_error(str(e))
45+
except Exception:
46+
logger.exception("Failed to report rollout error")
47+
raise
48+
finally:
49+
_clear_metadata()
50+
51+
return async_wrapper
52+
else:
53+
54+
@functools.wraps(func)
55+
def sync_wrapper(payload, *args, **kwargs):
56+
metadata = payload.get("metadata", {}) if isinstance(payload, dict) else {}
57+
_set_metadata(metadata)
58+
try:
59+
return func(payload, *args, **kwargs)
60+
except Exception as e:
61+
logger.error("RFT rollout failed: %s", e)
62+
try:
63+
RolloutFeedbackClient(metadata).report_error(str(e))
64+
except Exception:
65+
logger.exception("Failed to report rollout error")
66+
raise
67+
finally:
68+
_clear_metadata()
69+
70+
return sync_wrapper

0 commit comments

Comments
 (0)