테스트용 임베딩 모델 추가

This commit is contained in:
mskim 2025-12-15 09:57:37 +09:00
parent 66e02868a6
commit d01de88078
5 changed files with 225 additions and 5 deletions

View File

@ -0,0 +1,16 @@
package com.pandol365.dewey.config;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
@Configuration
public class EmbeddingConfig {
@Bean
public RestTemplate restTemplate(RestTemplateBuilder builder) {
return builder.build();
}
}

View File

@ -0,0 +1,108 @@
package com.pandol365.dewey.domain.memory.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.util.List;
import java.util.Map;
/**
* 임베딩 생성 클라이언트
* all-mpnet-base-v2 (768d) 같은 로컬/외부 서비스에 HTTP 호출
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class EmbeddingClient {
private final RestTemplate restTemplate;
@Value("${embedding.api.base-url:http://localhost:8000}")
private String embeddingBaseUrl;
@Value("${embedding.api.path:/embed}")
private String embeddingPath;
/**
* 텍스트 임베딩 생성
* 기대 응답 형식: { "embedding": [float, float, ...] }
*/
public float[] embed(String text) {
try {
String url = UriComponentsBuilder.newInstance()
.scheme(extractScheme(embeddingBaseUrl))
.host(extractHost(embeddingBaseUrl))
.port(extractPort(embeddingBaseUrl))
.path(embeddingPath == null ? "/embed" : embeddingPath)
.toUriString();
@SuppressWarnings("unchecked")
Map<String, Object> resp = restTemplate.postForObject(
url,
Map.of("text", text),
Map.class
);
if (resp == null || !resp.containsKey("embedding")) {
throw new IllegalStateException("embedding field missing in response");
}
Object embObj = resp.get("embedding");
if (!(embObj instanceof List<?> list)) {
throw new IllegalStateException("embedding is not a list");
}
float[] embedding = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
Object v = list.get(i);
if (v instanceof Number num) {
embedding[i] = num.floatValue();
} else {
throw new IllegalStateException("embedding element is not numeric");
}
}
return embedding;
} catch (Exception e) {
log.error("임베딩 생성 실패: {}", e.getMessage(), e);
throw new RuntimeException("Failed to generate embedding", e);
}
}
// 간단한 URL 파서 (기본값 포함)
private String extractScheme(String url) {
if (url == null || url.isBlank()) return "http";
int idx = url.indexOf("://");
return idx > 0 ? url.substring(0, idx) : "http";
}
private String extractHost(String url) {
if (url == null || url.isBlank()) return "localhost";
String noScheme = url.contains("://") ? url.substring(url.indexOf("://") + 3) : url;
int slash = noScheme.indexOf('/');
String hostPort = slash >= 0 ? noScheme.substring(0, slash) : noScheme;
int colon = hostPort.indexOf(':');
return colon >= 0 ? hostPort.substring(0, colon) : hostPort;
}
private Integer extractPort(String url) {
if (url == null || url.isBlank()) return null;
String noScheme = url.contains("://") ? url.substring(url.indexOf("://") + 3) : url;
int slash = noScheme.indexOf('/');
String hostPort = slash >= 0 ? noScheme.substring(0, slash) : noScheme;
int colon = hostPort.indexOf(':');
if (colon >= 0) {
String portStr = hostPort.substring(colon + 1);
try {
return Integer.parseInt(portStr);
} catch (NumberFormatException ignored) {
return null;
}
}
return null;
}
}

View File

@ -0,0 +1,73 @@
package com.pandol365.dewey.domain.memory.service;
import com.pandol365.dewey.domain.memory.model.Memory;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;
import com.pgvector.PGvector;
import java.time.LocalDateTime;
import java.util.List;
/**
* PGVector 기반 메모리 벡터 스토어
* 별도 테이블 memory_embeddings 저장/검색
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class MemoryVectorStore {
private final JdbcTemplate jdbcTemplate;
private static final String TABLE_NAME = "memory_embeddings";
/**
* 임베딩과 함께 저장 (permanent 목적)
*/
public void save(String userId, String memoryText, Integer importance, float[] embedding) {
try {
PGvector pgVector = new PGvector(embedding);
jdbcTemplate.update(
"INSERT INTO " + TABLE_NAME + " (user_id, memory_text, importance, created_at, embedding) VALUES (?, ?, ?, ?, ?)",
userId,
memoryText,
importance != null ? importance : 1,
LocalDateTime.now(),
pgVector
);
} catch (Exception e) {
log.error("memory_embeddings 저장 실패: {}", e.getMessage(), e);
throw new RuntimeException("Failed to store embedding", e);
}
}
/**
* 코사인 유사도 기반 검색
*/
public List<Memory> search(float[] queryEmbedding, int limit) {
try {
PGvector pgVector = new PGvector(queryEmbedding);
String sql = "SELECT id, user_id, memory_text, importance, created_at, " +
" (embedding <=> ?) AS distance " +
"FROM " + TABLE_NAME + " " +
"ORDER BY embedding <=> ? " +
"LIMIT ?";
return jdbcTemplate.query(sql, (rs, rowNum) -> {
Memory m = new Memory();
m.setId(rs.getLong("id"));
m.setUserId(rs.getString("user_id"));
m.setMemoryText(rs.getString("memory_text"));
m.setImportance(rs.getInt("importance"));
m.setCreatedAt(rs.getTimestamp("created_at").toLocalDateTime());
m.setUpdatedAt(null);
return m;
}, pgVector, pgVector, limit);
} catch (Exception e) {
log.error("memory_embeddings 검색 실패: {}", e.getMessage(), e);
throw new RuntimeException("Failed to search embeddings", e);
}
}
}

View File

@ -4,6 +4,8 @@ import com.pandol365.dewey.domain.memory.model.Memory;
import com.pandol365.dewey.domain.memory.model.TemporaryMemory;
import com.pandol365.dewey.domain.memory.repository.MemoryRepository;
import com.pandol365.dewey.domain.memory.service.MemoryService;
import com.pandol365.dewey.domain.memory.service.EmbeddingClient;
import com.pandol365.dewey.domain.memory.service.MemoryVectorStore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.domain.PageRequest;
@ -29,6 +31,8 @@ public class MemoryServiceImpl implements MemoryService {
private final MemoryRepository memoryRepository;
private final RedisTemplate<String, Object> redisTemplate;
private final EmbeddingClient embeddingClient;
private final MemoryVectorStore memoryVectorStore;
private static final String TEMP_MEMORY_KEY_PREFIX = "tempMemory:";
private static final String USER_TEMP_MEMORY_KEY_PREFIX = "user:tempMemories:";
@ -59,6 +63,14 @@ public class MemoryServiceImpl implements MemoryService {
redisTemplate.opsForZSet().add(userMemoriesKey, memoryKey, System.currentTimeMillis());
redisTemplate.expire(userMemoriesKey, ttl);
// 임베딩 생성 PGVector 저장 (실패해도 Redis 저장은 유지)
try {
float[] embedding = embeddingClient.embed(memoryText);
memoryVectorStore.save(userId, memoryText, importance, embedding);
} catch (Exception e) {
log.warn("임베딩 저장 실패 (계속 진행): {}", e.getMessage());
}
return temporaryMemory;
}
@ -109,10 +121,16 @@ public class MemoryServiceImpl implements MemoryService {
@Override
@Transactional(readOnly = true)
public List<Memory> searchMemoriesByVector(String query, String userId, Integer limit) {
log.info("벡터 기반 메모리 검색: query={}, userId={}, limit={}", query, userId, limit);
log.info("벡터 기반 메모리 검색: query={}, userId(optional)={}, limit={}", query, userId, limit);
// TODO: 벡터 유사도 검색 구현 (Spring AI PgVectorStore 사용)
return List.of();
int topK = limit != null && limit > 0 ? limit : 5;
try {
float[] embedding = embeddingClient.embed(query);
return memoryVectorStore.search(embedding, topK);
} catch (Exception e) {
log.error("벡터 검색 실패: {}", e.getMessage(), e);
return List.of();
}
}
@Override

View File

@ -1,7 +1,7 @@
spring.application.name=dewey
# PostgreSQL
spring.datasource.url=jdbc:postgresql://localhost:5432/dewey_memory
spring.datasource.url=jdbc:postgresql://192.168.100.230:5432/dewey_memory
spring.datasource.username=dewey
spring.datasource.password=0bk1rWu98mGl5ea3
spring.datasource.driver-class-name=org.postgresql.Driver
@ -15,3 +15,8 @@ spring.jpa.properties.hibernate.format_sql=true
# Redis
spring.data.redis.host=localhost
spring.data.redis.port=6379
# Embedding service
# all-mpnet-base-v2 768d embedding endpoint
embedding.api.base-url=http://localhost:8000
embedding.api.path=/embed