Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,8 @@ public interface DecisionLineRepository extends JpaRepository<DecisionLine, Long
@EntityGraph(attributePaths = {"user"})
Optional<DecisionLine> findWithUserById(Long id);

@EntityGraph(attributePaths = {"user", "baseLine", "baseLine.baseNodes"})
Optional<DecisionLine> findWithUserAndBaseLineById(Long id);

void deleteByBaseLine_Id(Long baseLineId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ private ScenarioValidationResult validateScenarioCreation(
ScenarioCreateRequest request,
@Nullable DecisionNodeNextRequest lastDecision) {

// DecisionLine 존재 여부 확인 (User EAGER 로딩)
DecisionLine decisionLine = decisionLineRepository.findWithUserById(request.decisionLineId())
// DecisionLine 존재 여부 확인 (User, BaseLine, BaseLine.baseNodes EAGER 로딩)
DecisionLine decisionLine = decisionLineRepository.findWithUserAndBaseLineById(request.decisionLineId())
.orElseThrow(() -> new ApiException(ErrorCode.DECISION_LINE_NOT_FOUND));

// 권한 검증
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public CompletableFuture<String> generateText(String prompt) {

@Override
public CompletableFuture<String> generateText(AiRequest aiRequest) {
log.info("[CLIENT] GeminiJsonTextClient (2.0) is being used.");
if (aiRequest == null || aiRequest.prompt() == null) {
return CompletableFuture.failedFuture(new AiParsingException("Prompt is null"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ public CompletableFuture<String> generateText(String prompt) {

@Override
public CompletableFuture<String> generateText(AiRequest aiRequest) {
log.info("[CLIENT] GeminiTextClient (2.5) is being used.");
return webClient
.post()
.uri("/v1beta/models/{model}:generateContent", textAiConfig.getModel())
.contentType(MediaType.APPLICATION_JSON)
.bodyValue(createGeminiRequest(aiRequest.prompt(), aiRequest.maxTokens()))
.bodyValue(createGeminiRequest(aiRequest))
.retrieve()
.onStatus(HttpStatusCode::isError, this::handleErrorResponse)
.bodyToMono(GeminiResponse.class)
Expand All @@ -64,17 +65,17 @@ public CompletableFuture<String> generateText(AiRequest aiRequest) {
.toFuture();
}

private Map<String, Object> createGeminiRequest(String prompt, int maxTokens) {
private Map<String, Object> createGeminiRequest(AiRequest aiRequest) {
// AiRequest로부터 generationConfig를 가져와 사용
java.util.Map<String, Object> generationConfig = new java.util.HashMap<>(aiRequest.parameters());
// maxTokens는 AiRequest의 전용 필드에서 가져와 확실히 설정
generationConfig.put("maxOutputTokens", aiRequest.maxTokens());

return Map.of(
"contents", List.of(
Map.of("parts", List.of(Map.of("text", prompt)))
Map.of("parts", List.of(Map.of("text", aiRequest.prompt())))
),
"generationConfig", Map.of(
"temperature", 0.8, // 시나리오 생성용 창의성 향상 (0.7 → 0.8)
"topK", 3, // 성능 최적화 (40 → 3, 10-15% 속도 향상)
"topP", 0.95,
"maxOutputTokens", maxTokens // AiRequest의 maxTokens 사용
)
"generationConfig", generationConfig
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,21 @@ public class BaseScenarioAiProperties {
private int maxOutputTokens = 1000;
private int timeoutSeconds = 60;

// Generation Config (AI 응답 품질 제어)
private double temperature = 0.7;
private double topP = 0.9;
private int topK = 40;

// getters/setters
public int getMaxOutputTokens() { return maxOutputTokens; }
public void setMaxOutputTokens(int maxOutputTokens) { this.maxOutputTokens = maxOutputTokens; }
public int getTimeoutSeconds() { return timeoutSeconds; }
public void setTimeoutSeconds(int timeoutSeconds) { this.timeoutSeconds = timeoutSeconds; }

public double getTemperature() { return temperature; }
public void setTemperature(double temperature) { this.temperature = temperature; }
public double getTopP() { return topP; }
public void setTopP(double topP) { this.topP = topP; }
public int getTopK() { return topK; }
public void setTopK(int topK) { this.topK = topK; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,21 @@ public class DecisionScenarioAiProperties {
private int maxOutputTokens = 1200;
private int timeoutSeconds = 60;

// Generation Config (AI 응답 품질 제어)
private double temperature = 0.7;
private double topP = 0.9;
private int topK = 40;

// getters/setters
public int getMaxOutputTokens() { return maxOutputTokens; }
public void setMaxOutputTokens(int maxOutputTokens) { this.maxOutputTokens = maxOutputTokens; }
public int getTimeoutSeconds() { return timeoutSeconds; }
public void setTimeoutSeconds(int timeoutSeconds) { this.timeoutSeconds = timeoutSeconds; }

public double getTemperature() { return temperature; }
public void setTemperature(double temperature) { this.temperature = temperature; }
public double getTopP() { return topP; }
public void setTopP(double topP) { this.topP = topP; }
public int getTopK() { return topK; }
public void setTopK(int topK) { this.topK = topK; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ public record Indicator(
* indicators 배열을 Map<Type, Integer>로 변환
*/
public Map<Type, Integer> indicatorScores() {
if (indicators == null) {
return java.util.Collections.emptyMap();
}
return indicators.stream()
.collect(java.util.stream.Collectors.toMap(
ind -> Type.valueOf(ind.type),
Expand All @@ -42,6 +45,9 @@ public Map<Type, Integer> indicatorScores() {
* indicators 배열을 Map<Type, String>로 변환
*/
public Map<Type, String> indicatorAnalysis() {
if (indicators == null) {
return java.util.Collections.emptyMap();
}
return indicators.stream()
.collect(java.util.stream.Collectors.toMap(
ind -> Type.valueOf(ind.type),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ public record Comparison(
* indicators 배열을 Map<Type, Integer>로 변환
*/
public Map<Type, Integer> indicatorScores() {
if (indicators == null) {
return java.util.Collections.emptyMap();
}
return indicators.stream()
.collect(java.util.stream.Collectors.toMap(
ind -> Type.valueOf(ind.type),
Expand All @@ -53,6 +56,9 @@ public Map<Type, Integer> indicatorScores() {
* indicators 배열을 Map<Type, String>로 변환
*/
public Map<Type, String> indicatorAnalysis() {
if (indicators == null) {
return java.util.Collections.emptyMap();
}
return indicators.stream()
.collect(java.util.stream.Collectors.toMap(
ind -> Type.valueOf(ind.type),
Expand All @@ -64,6 +70,9 @@ public Map<Type, String> indicatorAnalysis() {
* comparisons 배열을 Map<String, String>로 변환
*/
public Map<String, String> comparisonResults() {
if (comparisons == null) {
return java.util.Collections.emptyMap();
}
return comparisons.stream()
.collect(java.util.stream.Collectors.toMap(
Comparison::type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class BaseScenarioPrompt {
## 현재 삶 정보
베이스라인: {baselineDescription}

## 현재 분기점들
## 과거 주요 인생 기록
{baseNodes}

## 요구사항 (JSON 형식)
Expand Down Expand Up @@ -99,13 +99,13 @@ public static String generatePrompt(BaseLine baseLine) {
BaseNode node = baseNodes.get(i);
int actualYear = birthYear + node.getAgeYear() - 1; // 실제 연도 계산

baseNodesInfo.append(String.format("%d. 카테고리: %s | 나이: %d세 (%d년) | 상황: %s | 결정: %s\n",
baseNodesInfo.append(String.format("%d. 카테고리: %s | 나이: %d세 (%d년) | 사건: %s | 결과: %s\n",
i + 1,
node.getCategory() != null ? node.getCategory().name() : "없음",
node.getAgeYear(),
actualYear,
node.getSituation() != null ? node.getSituation() : "상황 없음",
node.getDecision() != null ? node.getDecision() : "결정 없음"));
node.getSituation() != null ? node.getSituation() : "사건 없음",
node.getDecision() != null ? node.getDecision() : "결과 없음"));

// 가장 최근 노드의 연도를 시나리오 기준 연도로 사용
if (i == baseNodes.size() - 1) {
Expand All @@ -117,13 +117,18 @@ public static String generatePrompt(BaseLine baseLine) {
int currentYear = java.time.LocalDate.now().getYear();
int userCurrentAge = currentYear - birthYear + 1;

// BaseNode들의 실제 연도들을 타임라인 연도로 사용
// 맨 처음과 맨 끝 노드를 제외한 중간 노드들의 연도만 타임라인에 사용
StringBuilder timelineYears = new StringBuilder();
for (int i = 0; i < baseNodes.size(); i++) {
BaseNode node = baseNodes.get(i);
int actualYear = birthYear + node.getAgeYear() - 1;
if (i > 0) timelineYears.append(", ");
timelineYears.append('"').append(actualYear).append('"').append(": \"제목 (5단어 이내)\"");
if (baseNodes.size() > 2) {
java.util.List<BaseNode> intermediateNodes = baseNodes.subList(1, baseNodes.size() - 1);
for (int i = 0; i < intermediateNodes.size(); i++) {
BaseNode node = intermediateNodes.get(i);
int actualYear = birthYear + node.getAgeYear() - 1;
if (i > 0) {
timelineYears.append(", ");
}
timelineYears.append('"').append(actualYear).append('"').append(": \"해당 연도 요약 (5단어 이내)\"");
}
}

// 사용자 정보 추출 (null-safe)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class DecisionScenarioPrompt {
설명: {baseDescription}
지표: {baseIndicators}

## 대안 선택 경로
## 새로운 선택 기록
{decisionNodes}

## 요구사항 (JSON 형식)
Expand Down Expand Up @@ -116,12 +116,12 @@ public static String generatePrompt(DecisionLine decisionLine, Scenario baseScen
int actualYear = birthYear + node.getAgeYear() - 1; // 실제 연도 계산

decisionNodesInfo.append(String.format(
"%d단계 선택 (%d세, %d년):\n상황: %s\n결정: %s\n\n",
"%d단계 선택 (%d세, %d년):\n사건: %s\n선택: %s\n\n",
i + 1,
node.getAgeYear(),
actualYear,
node.getSituation() != null ? node.getSituation() : "상황 정보 없음",
node.getDecision() != null ? node.getDecision() : "결정 정보 없음"
node.getSituation() != null ? node.getSituation() : "사건 정보 없음",
node.getDecision() != null ? node.getDecision() : "선택 정보 없음"
));
}

Expand All @@ -146,13 +146,18 @@ public static String generatePrompt(DecisionLine decisionLine, Scenario baseScen
int careerScore = getScoreByType(baseSceneTypes, "직업");
int healthScore = getScoreByType(baseSceneTypes, "건강");

// DecisionNode들의 실제 연도들을 타임라인 연도로 사용
// 맨 처음과 맨 끝 노드를 제외한 중간 노드들의 연도만 타임라인에 사용
StringBuilder timelineYears = new StringBuilder();
for (int i = 0; i < decisionNodes.size(); i++) {
DecisionNode node = decisionNodes.get(i);
int actualYear = birthYear + node.getAgeYear() - 1;
if (i > 0) timelineYears.append(", ");
timelineYears.append('"').append(actualYear).append('"').append(": \"제목 (5단어 이내)\"");
if (decisionNodes.size() > 2) {
java.util.List<DecisionNode> intermediateNodes = decisionNodes.subList(1, decisionNodes.size() - 1);
for (int i = 0; i < intermediateNodes.size(); i++) {
DecisionNode node = intermediateNodes.get(i);
int actualYear = birthYear + node.getAgeYear() - 1;
if (i > 0) {
timelineYears.append(", ");
}
timelineYears.append('"').append(actualYear).append('"').append(": \"해당 연도 요약 (5단어 이내)\"");
}
}

// 사용자 정보 추출 (null-safe)
Expand Down
55 changes: 48 additions & 7 deletions back/src/main/java/com/back/global/ai/service/AiServiceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
*/

@Service
@RequiredArgsConstructor
@Slf4j
public class AiServiceImpl implements AiService {

private final @Qualifier("gemini25TextClient") TextAiClient textAiClient;
private final TextAiClient scenarioClient;
private final TextAiClient situationClient;
private final ObjectMapper objectMapper;
private final SceneTypeRepository sceneTypeRepository;
private final SituationAiProperties situationAiProperties;
Expand All @@ -48,6 +48,26 @@ public class AiServiceImpl implements AiService {
private final ImageAiClient imageAiClient;
private final com.back.global.storage.StorageService storageService;

public AiServiceImpl(@Qualifier("gemini25TextClient") TextAiClient scenarioClient,
@Qualifier("gemini20JsonClient") TextAiClient situationClient,
ObjectMapper objectMapper,
SceneTypeRepository sceneTypeRepository,
SituationAiProperties situationAiProperties,
BaseScenarioAiProperties baseScenarioAiProperties,
DecisionScenarioAiProperties decisionScenarioAiProperties,
ImageAiClient imageAiClient,
com.back.global.storage.StorageService storageService) {
this.scenarioClient = scenarioClient;
this.situationClient = situationClient;
this.objectMapper = objectMapper;
this.sceneTypeRepository = sceneTypeRepository;
this.situationAiProperties = situationAiProperties;
this.baseScenarioAiProperties = baseScenarioAiProperties;
this.decisionScenarioAiProperties = decisionScenarioAiProperties;
this.imageAiClient = imageAiClient;
this.storageService = storageService;
}

@Override
public CompletableFuture<BaseScenarioResult> generateBaseScenario(BaseLine baseLine) {
if (baseLine == null) {
Expand All @@ -65,10 +85,21 @@ public CompletableFuture<BaseScenarioResult> generateBaseScenario(BaseLine baseL
// Step 2: AI 호출 및 파싱
int maxTokens = baseScenarioAiProperties.getMaxOutputTokens();
log.info("Using maxOutputTokens: {} for base scenario generation", maxTokens);
AiRequest request = new AiRequest(baseScenarioPrompt, Map.of(), maxTokens);
return textAiClient.generateText(request)

// JSON 모드 강제 + 구조화된 응답 유도 (application.yml에서 관리)
Map<String, Object> generationConfig = Map.of(
"temperature", baseScenarioAiProperties.getTemperature(),
"topP", baseScenarioAiProperties.getTopP(),
"topK", baseScenarioAiProperties.getTopK(),
"candidateCount", 1,
"response_mime_type", "application/json" // JSON 모드 강제
);

AiRequest request = new AiRequest(baseScenarioPrompt, generationConfig, maxTokens);
return scenarioClient.generateText(request)
.thenApply(aiResponse -> {
try {
log.info("Raw AI response for BaseLine ID: {}: {}", baseLine.getId(), aiResponse);
log.debug("Received AI response for BaseLine ID: {}, length: {}",
baseLine.getId(), aiResponse.length());
// Remove markdown code block wrappers (```json ... ```)
Expand Down Expand Up @@ -131,10 +162,20 @@ public CompletableFuture<DecisionScenarioResult> generateDecisionScenario(Decisi
log.debug("Generated decision scenario prompt for DecisionLine ID: {}", decisionLine.getId());

// Step 2: AI 호출 및 파싱
AiRequest request = new AiRequest(newScenarioPrompt, Map.of(), decisionScenarioAiProperties.getMaxOutputTokens());
return textAiClient.generateText(request)
// JSON 모드 강제 + 구조화된 응답 유도 (application.yml에서 관리)
Map<String, Object> generationConfig = Map.of(
"temperature", decisionScenarioAiProperties.getTemperature(),
"topP", decisionScenarioAiProperties.getTopP(),
"topK", decisionScenarioAiProperties.getTopK(),
"candidateCount", 1,
"response_mime_type", "application/json" // JSON 모드 강제
);

AiRequest request = new AiRequest(newScenarioPrompt, generationConfig, decisionScenarioAiProperties.getMaxOutputTokens());
return scenarioClient.generateText(request)
.thenApply(aiResponse -> {
try {
log.info("Raw AI response for DecisionLine ID: {}: {}", decisionLine.getId(), aiResponse);
log.debug("Received AI response for DecisionLine ID: {}, length: {}",
decisionLine.getId(), aiResponse.length());
// Remove markdown code block wrappers (```json ... ```)
Expand Down Expand Up @@ -215,7 +256,7 @@ public CompletableFuture<String> generateSituation(List<DecisionNode> previousNo

// Step 2: AI 호출 및 상황 텍스트 추출
AiRequest request = new AiRequest(situationPrompt, Map.of(), situationAiProperties.getMaxOutputTokens());
return textAiClient.generateText(request)
return situationClient.generateText(request)
.thenApply(aiResponse -> {
try {
log.debug("Received AI response for situation generation, length: {}",
Expand Down
Loading