diff --git a/src/main/java/com/pandol365/dewey/config/EmbeddingConfig.java b/src/main/java/com/pandol365/dewey/config/EmbeddingConfig.java new file mode 100644 index 0000000..8d08c7f --- /dev/null +++ b/src/main/java/com/pandol365/dewey/config/EmbeddingConfig.java @@ -0,0 +1,16 @@ +package com.pandol365.dewey.config; + +import org.springframework.boot.web.client.RestTemplateBuilder; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.client.RestTemplate; + +@Configuration +public class EmbeddingConfig { + + @Bean + public RestTemplate restTemplate(RestTemplateBuilder builder) { + return builder.build(); + } +} + diff --git a/src/main/java/com/pandol365/dewey/domain/memory/service/EmbeddingClient.java b/src/main/java/com/pandol365/dewey/domain/memory/service/EmbeddingClient.java new file mode 100644 index 0000000..1a44230 --- /dev/null +++ b/src/main/java/com/pandol365/dewey/domain/memory/service/EmbeddingClient.java @@ -0,0 +1,108 @@ +package com.pandol365.dewey.domain.memory.service; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.util.UriComponentsBuilder; + +import java.util.List; +import java.util.Map; + +/** + * 임베딩 생성 클라이언트 + * all-mpnet-base-v2 (768d) 같은 로컬/외부 서비스에 HTTP 호출 + */ +@Slf4j +@Component +@RequiredArgsConstructor +public class EmbeddingClient { + + private final RestTemplate restTemplate; + + @Value("${embedding.api.base-url:http://localhost:8000}") + private String embeddingBaseUrl; + + @Value("${embedding.api.path:/embed}") + private String embeddingPath; + + /** + * 텍스트 임베딩 생성 + * 기대 응답 형식: { "embedding": [float, float, ...] } + */ + public float[] embed(String text) { + try { + String url = UriComponentsBuilder.newInstance() + .scheme(extractScheme(embeddingBaseUrl)) + .host(extractHost(embeddingBaseUrl)) + .port(extractPort(embeddingBaseUrl)) + .path(embeddingPath == null ? "/embed" : embeddingPath) + .toUriString(); + + @SuppressWarnings("unchecked") + Map resp = restTemplate.postForObject( + url, + Map.of("text", text), + Map.class + ); + + if (resp == null || !resp.containsKey("embedding")) { + throw new IllegalStateException("embedding field missing in response"); + } + + Object embObj = resp.get("embedding"); + if (!(embObj instanceof List list)) { + throw new IllegalStateException("embedding is not a list"); + } + + float[] embedding = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + Object v = list.get(i); + if (v instanceof Number num) { + embedding[i] = num.floatValue(); + } else { + throw new IllegalStateException("embedding element is not numeric"); + } + } + return embedding; + } catch (Exception e) { + log.error("임베딩 생성 실패: {}", e.getMessage(), e); + throw new RuntimeException("Failed to generate embedding", e); + } + } + + // 간단한 URL 파서 (기본값 포함) + private String extractScheme(String url) { + if (url == null || url.isBlank()) return "http"; + int idx = url.indexOf("://"); + return idx > 0 ? url.substring(0, idx) : "http"; + } + + private String extractHost(String url) { + if (url == null || url.isBlank()) return "localhost"; + String noScheme = url.contains("://") ? url.substring(url.indexOf("://") + 3) : url; + int slash = noScheme.indexOf('/'); + String hostPort = slash >= 0 ? noScheme.substring(0, slash) : noScheme; + int colon = hostPort.indexOf(':'); + return colon >= 0 ? hostPort.substring(0, colon) : hostPort; + } + + private Integer extractPort(String url) { + if (url == null || url.isBlank()) return null; + String noScheme = url.contains("://") ? url.substring(url.indexOf("://") + 3) : url; + int slash = noScheme.indexOf('/'); + String hostPort = slash >= 0 ? noScheme.substring(0, slash) : noScheme; + int colon = hostPort.indexOf(':'); + if (colon >= 0) { + String portStr = hostPort.substring(colon + 1); + try { + return Integer.parseInt(portStr); + } catch (NumberFormatException ignored) { + return null; + } + } + return null; + } +} + diff --git a/src/main/java/com/pandol365/dewey/domain/memory/service/MemoryVectorStore.java b/src/main/java/com/pandol365/dewey/domain/memory/service/MemoryVectorStore.java new file mode 100644 index 0000000..8dca3a7 --- /dev/null +++ b/src/main/java/com/pandol365/dewey/domain/memory/service/MemoryVectorStore.java @@ -0,0 +1,73 @@ +package com.pandol365.dewey.domain.memory.service; + +import com.pandol365.dewey.domain.memory.model.Memory; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Component; +import com.pgvector.PGvector; + +import java.time.LocalDateTime; +import java.util.List; + +/** + * PGVector 기반 메모리 벡터 스토어 + * 별도 테이블 memory_embeddings 에 저장/검색 + */ +@Slf4j +@Component +@RequiredArgsConstructor +public class MemoryVectorStore { + + private final JdbcTemplate jdbcTemplate; + + private static final String TABLE_NAME = "memory_embeddings"; + + /** + * 임베딩과 함께 저장 (permanent 목적) + */ + public void save(String userId, String memoryText, Integer importance, float[] embedding) { + try { + PGvector pgVector = new PGvector(embedding); + jdbcTemplate.update( + "INSERT INTO " + TABLE_NAME + " (user_id, memory_text, importance, created_at, embedding) VALUES (?, ?, ?, ?, ?)", + userId, + memoryText, + importance != null ? importance : 1, + LocalDateTime.now(), + pgVector + ); + } catch (Exception e) { + log.error("memory_embeddings 저장 실패: {}", e.getMessage(), e); + throw new RuntimeException("Failed to store embedding", e); + } + } + + /** + * 코사인 유사도 기반 검색 + */ + public List search(float[] queryEmbedding, int limit) { + try { + PGvector pgVector = new PGvector(queryEmbedding); + String sql = "SELECT id, user_id, memory_text, importance, created_at, " + + " (embedding <=> ?) AS distance " + + "FROM " + TABLE_NAME + " " + + "ORDER BY embedding <=> ? " + + "LIMIT ?"; + return jdbcTemplate.query(sql, (rs, rowNum) -> { + Memory m = new Memory(); + m.setId(rs.getLong("id")); + m.setUserId(rs.getString("user_id")); + m.setMemoryText(rs.getString("memory_text")); + m.setImportance(rs.getInt("importance")); + m.setCreatedAt(rs.getTimestamp("created_at").toLocalDateTime()); + m.setUpdatedAt(null); + return m; + }, pgVector, pgVector, limit); + } catch (Exception e) { + log.error("memory_embeddings 검색 실패: {}", e.getMessage(), e); + throw new RuntimeException("Failed to search embeddings", e); + } + } +} + diff --git a/src/main/java/com/pandol365/dewey/domain/memory/service/impl/MemoryServiceImpl.java b/src/main/java/com/pandol365/dewey/domain/memory/service/impl/MemoryServiceImpl.java index 1119d78..6065b52 100644 --- a/src/main/java/com/pandol365/dewey/domain/memory/service/impl/MemoryServiceImpl.java +++ b/src/main/java/com/pandol365/dewey/domain/memory/service/impl/MemoryServiceImpl.java @@ -4,6 +4,8 @@ import com.pandol365.dewey.domain.memory.model.Memory; import com.pandol365.dewey.domain.memory.model.TemporaryMemory; import com.pandol365.dewey.domain.memory.repository.MemoryRepository; import com.pandol365.dewey.domain.memory.service.MemoryService; +import com.pandol365.dewey.domain.memory.service.EmbeddingClient; +import com.pandol365.dewey.domain.memory.service.MemoryVectorStore; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.data.domain.PageRequest; @@ -29,6 +31,8 @@ public class MemoryServiceImpl implements MemoryService { private final MemoryRepository memoryRepository; private final RedisTemplate redisTemplate; + private final EmbeddingClient embeddingClient; + private final MemoryVectorStore memoryVectorStore; private static final String TEMP_MEMORY_KEY_PREFIX = "tempMemory:"; private static final String USER_TEMP_MEMORY_KEY_PREFIX = "user:tempMemories:"; @@ -58,6 +62,14 @@ public class MemoryServiceImpl implements MemoryService { String userMemoriesKey = USER_TEMP_MEMORY_KEY_PREFIX + userId; redisTemplate.opsForZSet().add(userMemoriesKey, memoryKey, System.currentTimeMillis()); redisTemplate.expire(userMemoriesKey, ttl); + + // 임베딩 생성 및 PGVector 저장 (실패해도 Redis 저장은 유지) + try { + float[] embedding = embeddingClient.embed(memoryText); + memoryVectorStore.save(userId, memoryText, importance, embedding); + } catch (Exception e) { + log.warn("임베딩 저장 실패 (계속 진행): {}", e.getMessage()); + } return temporaryMemory; } @@ -109,10 +121,16 @@ public class MemoryServiceImpl implements MemoryService { @Override @Transactional(readOnly = true) public List searchMemoriesByVector(String query, String userId, Integer limit) { - log.info("벡터 기반 메모리 검색: query={}, userId={}, limit={}", query, userId, limit); - - // TODO: 벡터 유사도 검색 구현 (Spring AI PgVectorStore 사용) - return List.of(); + log.info("벡터 기반 메모리 검색: query={}, userId(optional)={}, limit={}", query, userId, limit); + + int topK = limit != null && limit > 0 ? limit : 5; + try { + float[] embedding = embeddingClient.embed(query); + return memoryVectorStore.search(embedding, topK); + } catch (Exception e) { + log.error("벡터 검색 실패: {}", e.getMessage(), e); + return List.of(); + } } @Override diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 1f931c3..c90fb00 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -1,7 +1,7 @@ spring.application.name=dewey # PostgreSQL -spring.datasource.url=jdbc:postgresql://localhost:5432/dewey_memory +spring.datasource.url=jdbc:postgresql://192.168.100.230:5432/dewey_memory spring.datasource.username=dewey spring.datasource.password=0bk1rWu98mGl5ea3 spring.datasource.driver-class-name=org.postgresql.Driver @@ -15,3 +15,8 @@ spring.jpa.properties.hibernate.format_sql=true # Redis spring.data.redis.host=localhost spring.data.redis.port=6379 + +# Embedding service +# all-mpnet-base-v2 768d embedding endpoint +embedding.api.base-url=http://localhost:8000 +embedding.api.path=/embed