Anirudh Esthuri committed on
Commit 4a34f6e · 1 Parent(s): d349f76

Add inference profile ARNs for provisioned throughput Claude models

Files changed (2)
  1. llm.py +31 -13
  2. model_config.py +7 -0
llm.py CHANGED
@@ -6,7 +6,7 @@ import boto3
 import openai
 import requests
 from dotenv import load_dotenv
-from model_config import MODEL_TO_PROVIDER
+from model_config import MODEL_TO_PROVIDER, MODEL_TO_INFERENCE_PROFILE_ARN
 
 # ──────────────────────────────────────────────────────────────
 # Load environment variables
@@ -119,11 +119,13 @@ def chat(messages, persona):
 
     try:
         bedrock_runtime = get_bedrock_client()
-        response = bedrock_runtime.invoke_model(
-            modelId=MODEL_STRING,
-            contentType="application/json",
-            accept="application/json",
-            body=json.dumps(
+
+        # Use the inference profile ARN if one is configured (provisioned
+        # throughput models); otherwise fall back to the plain model id.
+        invoke_kwargs = {
+            "contentType": "application/json",
+            "accept": "application/json",
+            "body": json.dumps(
                 {
                     "anthropic_version": "bedrock-2023-05-31",
                     "system": system_prompt,
@@ -132,7 +134,15 @@ def chat(messages, persona):
                     "temperature": 0.3,  # Lower temperature for more focused responses
                 }
             ),
-        )
+        }
+
+        # InvokeModel accepts an inference profile ARN directly in modelId
+        if MODEL_STRING in MODEL_TO_INFERENCE_PROFILE_ARN:
+            invoke_kwargs["modelId"] = MODEL_TO_INFERENCE_PROFILE_ARN[MODEL_STRING]
+        else:
+            invoke_kwargs["modelId"] = MODEL_STRING
+
+        response = bedrock_runtime.invoke_model(**invoke_kwargs)
 
         dt = time.time() - t0
         body = json.loads(response["body"].read())
@@ -374,17 +384,25 @@ def check_credentials():
     try:
         bedrock_runtime = get_bedrock_client()
         # Try a simple test invocation to verify credentials
-        test_response = bedrock_runtime.invoke_model(
-            modelId="anthropic.claude-haiku-4-5-20251001-v1:0",
-            contentType="application/json",
-            accept="application/json",
-            body=json.dumps({
+        test_model = "anthropic.claude-haiku-4-5-20251001-v1:0"
+        test_kwargs = {
+            "contentType": "application/json",
+            "accept": "application/json",
+            "body": json.dumps({
                 "anthropic_version": "bedrock-2023-05-31",
                 "messages": [{"role": "user", "content": "test"}],
                 "max_tokens": 10,
                 "temperature": 0.1
            })
-        )
+        }
+
+        # Use the inference profile ARN if one is configured for the test model
+        if test_model in MODEL_TO_INFERENCE_PROFILE_ARN:
+            test_kwargs["modelId"] = MODEL_TO_INFERENCE_PROFILE_ARN[test_model]
+        else:
+            test_kwargs["modelId"] = test_model
+
+        test_response = bedrock_runtime.invoke_model(**test_kwargs)
         print("Bedrock connection successful")
         return True
     except Exception as e:
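For readers skimming the diff, the change reduces to a lookup before the call. The sketch below is illustrative only (the helper names resolve_invoke_model_id and invoke are not from this repo); it relies on the fact that Bedrock's InvokeModel accepts an inference profile ARN in the same modelId field used for plain model ids, which is why no separate parameter is needed:

import json

import boto3

# Illustrative map: one entry copied from model_config.py in this commit.
MODEL_TO_INFERENCE_PROFILE_ARN = {
    "anthropic.claude-haiku-4-5-20251001-v1:0":
        "arn:aws:bedrock:us-east-1:850995563530:inference-profile/global.anthropic.claude-haiku-4-5-20251001-v1:0",
}

def resolve_invoke_model_id(model_string: str) -> str:
    # InvokeModel takes either a model id or an inference profile ARN in
    # modelId, so the dispatch collapses to a single dict lookup.
    return MODEL_TO_INFERENCE_PROFILE_ARN.get(model_string, model_string)

def invoke(model_string: str, payload: dict) -> dict:
    client = boto3.client("bedrock-runtime", region_name="us-east-1")
    response = client.invoke_model(
        modelId=resolve_invoke_model_id(model_string),
        contentType="application/json",
        accept="application/json",
        body=json.dumps(payload),
    )
    return json.loads(response["body"].read())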
model_config.py CHANGED
@@ -31,3 +31,10 @@ MODEL_DISPLAY_NAMES = {
 }
 
 MODEL_CHOICES = [model for models in PROVIDER_MODEL_MAP.values() for model in models]
+
+# Inference profile ARNs for provisioned throughput models
+MODEL_TO_INFERENCE_PROFILE_ARN = {
+    "anthropic.claude-haiku-4-5-20251001-v1:0": "arn:aws:bedrock:us-east-1:850995563530:inference-profile/global.anthropic.claude-haiku-4-5-20251001-v1:0",
+    "anthropic.claude-sonnet-4-5-20250929-v1:0": "arn:aws:bedrock:us-east-1:850995563530:inference-profile/global.anthropic.claude-sonnet-4-5-20250929-v1:0",
+    "anthropic.claude-opus-4-20250514-v1:0": "arn:aws:bedrock:us-east-1:850995563530:inference-profile/global.anthropic.claude-opus-4-20250514-v1:0",
+}
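Since each of these global profile ARNs embeds the model id it serves, a small consistency check can guard the table against copy-paste slips (such as an opus key pointing at a sonnet profile, corrected above). This is a hypothetical addition, not part of the commit:

# Hypothetical sanity check for model_config.py (not in this commit).
from model_config import MODEL_TO_INFERENCE_PROFILE_ARN

def check_profile_arns() -> None:
    # For the global cross-region profiles listed here, each ARN should end
    # in "inference-profile/global.<model id>" for its own key.
    for model_id, arn in MODEL_TO_INFERENCE_PROFILE_ARN.items():
        expected = f"inference-profile/global.{model_id}"
        assert arn.endswith(expected), f"ARN mismatch for {model_id}: {arn}"

if __name__ == "__main__":
    check_profile_arns()
    print("All inference profile ARNs match their model ids")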