테스트용 임베딩 모델 추가
This commit is contained in:
parent
66e02868a6
commit
d01de88078
|
|
@ -0,0 +1,16 @@
|
|||
package com.pandol365.dewey.config;
|
||||
|
||||
import org.springframework.boot.web.client.RestTemplateBuilder;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
@Configuration
|
||||
public class EmbeddingConfig {
|
||||
|
||||
@Bean
|
||||
public RestTemplate restTemplate(RestTemplateBuilder builder) {
|
||||
return builder.build();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
package com.pandol365.dewey.domain.memory.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 임베딩 생성 클라이언트
|
||||
* all-mpnet-base-v2 (768d) 같은 로컬/외부 서비스에 HTTP 호출
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class EmbeddingClient {
|
||||
|
||||
private final RestTemplate restTemplate;
|
||||
|
||||
@Value("${embedding.api.base-url:http://localhost:8000}")
|
||||
private String embeddingBaseUrl;
|
||||
|
||||
@Value("${embedding.api.path:/embed}")
|
||||
private String embeddingPath;
|
||||
|
||||
/**
|
||||
* 텍스트 임베딩 생성
|
||||
* 기대 응답 형식: { "embedding": [float, float, ...] }
|
||||
*/
|
||||
public float[] embed(String text) {
|
||||
try {
|
||||
String url = UriComponentsBuilder.newInstance()
|
||||
.scheme(extractScheme(embeddingBaseUrl))
|
||||
.host(extractHost(embeddingBaseUrl))
|
||||
.port(extractPort(embeddingBaseUrl))
|
||||
.path(embeddingPath == null ? "/embed" : embeddingPath)
|
||||
.toUriString();
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resp = restTemplate.postForObject(
|
||||
url,
|
||||
Map.of("text", text),
|
||||
Map.class
|
||||
);
|
||||
|
||||
if (resp == null || !resp.containsKey("embedding")) {
|
||||
throw new IllegalStateException("embedding field missing in response");
|
||||
}
|
||||
|
||||
Object embObj = resp.get("embedding");
|
||||
if (!(embObj instanceof List<?> list)) {
|
||||
throw new IllegalStateException("embedding is not a list");
|
||||
}
|
||||
|
||||
float[] embedding = new float[list.size()];
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
Object v = list.get(i);
|
||||
if (v instanceof Number num) {
|
||||
embedding[i] = num.floatValue();
|
||||
} else {
|
||||
throw new IllegalStateException("embedding element is not numeric");
|
||||
}
|
||||
}
|
||||
return embedding;
|
||||
} catch (Exception e) {
|
||||
log.error("임베딩 생성 실패: {}", e.getMessage(), e);
|
||||
throw new RuntimeException("Failed to generate embedding", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 간단한 URL 파서 (기본값 포함)
|
||||
private String extractScheme(String url) {
|
||||
if (url == null || url.isBlank()) return "http";
|
||||
int idx = url.indexOf("://");
|
||||
return idx > 0 ? url.substring(0, idx) : "http";
|
||||
}
|
||||
|
||||
private String extractHost(String url) {
|
||||
if (url == null || url.isBlank()) return "localhost";
|
||||
String noScheme = url.contains("://") ? url.substring(url.indexOf("://") + 3) : url;
|
||||
int slash = noScheme.indexOf('/');
|
||||
String hostPort = slash >= 0 ? noScheme.substring(0, slash) : noScheme;
|
||||
int colon = hostPort.indexOf(':');
|
||||
return colon >= 0 ? hostPort.substring(0, colon) : hostPort;
|
||||
}
|
||||
|
||||
private Integer extractPort(String url) {
|
||||
if (url == null || url.isBlank()) return null;
|
||||
String noScheme = url.contains("://") ? url.substring(url.indexOf("://") + 3) : url;
|
||||
int slash = noScheme.indexOf('/');
|
||||
String hostPort = slash >= 0 ? noScheme.substring(0, slash) : noScheme;
|
||||
int colon = hostPort.indexOf(':');
|
||||
if (colon >= 0) {
|
||||
String portStr = hostPort.substring(colon + 1);
|
||||
try {
|
||||
return Integer.parseInt(portStr);
|
||||
} catch (NumberFormatException ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
package com.pandol365.dewey.domain.memory.service;
|
||||
|
||||
import com.pandol365.dewey.domain.memory.model.Memory;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.stereotype.Component;
|
||||
import com.pgvector.PGvector;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* PGVector 기반 메모리 벡터 스토어
|
||||
* 별도 테이블 memory_embeddings 에 저장/검색
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class MemoryVectorStore {
|
||||
|
||||
private final JdbcTemplate jdbcTemplate;
|
||||
|
||||
private static final String TABLE_NAME = "memory_embeddings";
|
||||
|
||||
/**
|
||||
* 임베딩과 함께 저장 (permanent 목적)
|
||||
*/
|
||||
public void save(String userId, String memoryText, Integer importance, float[] embedding) {
|
||||
try {
|
||||
PGvector pgVector = new PGvector(embedding);
|
||||
jdbcTemplate.update(
|
||||
"INSERT INTO " + TABLE_NAME + " (user_id, memory_text, importance, created_at, embedding) VALUES (?, ?, ?, ?, ?)",
|
||||
userId,
|
||||
memoryText,
|
||||
importance != null ? importance : 1,
|
||||
LocalDateTime.now(),
|
||||
pgVector
|
||||
);
|
||||
} catch (Exception e) {
|
||||
log.error("memory_embeddings 저장 실패: {}", e.getMessage(), e);
|
||||
throw new RuntimeException("Failed to store embedding", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 코사인 유사도 기반 검색
|
||||
*/
|
||||
public List<Memory> search(float[] queryEmbedding, int limit) {
|
||||
try {
|
||||
PGvector pgVector = new PGvector(queryEmbedding);
|
||||
String sql = "SELECT id, user_id, memory_text, importance, created_at, " +
|
||||
" (embedding <=> ?) AS distance " +
|
||||
"FROM " + TABLE_NAME + " " +
|
||||
"ORDER BY embedding <=> ? " +
|
||||
"LIMIT ?";
|
||||
return jdbcTemplate.query(sql, (rs, rowNum) -> {
|
||||
Memory m = new Memory();
|
||||
m.setId(rs.getLong("id"));
|
||||
m.setUserId(rs.getString("user_id"));
|
||||
m.setMemoryText(rs.getString("memory_text"));
|
||||
m.setImportance(rs.getInt("importance"));
|
||||
m.setCreatedAt(rs.getTimestamp("created_at").toLocalDateTime());
|
||||
m.setUpdatedAt(null);
|
||||
return m;
|
||||
}, pgVector, pgVector, limit);
|
||||
} catch (Exception e) {
|
||||
log.error("memory_embeddings 검색 실패: {}", e.getMessage(), e);
|
||||
throw new RuntimeException("Failed to search embeddings", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -4,6 +4,8 @@ import com.pandol365.dewey.domain.memory.model.Memory;
|
|||
import com.pandol365.dewey.domain.memory.model.TemporaryMemory;
|
||||
import com.pandol365.dewey.domain.memory.repository.MemoryRepository;
|
||||
import com.pandol365.dewey.domain.memory.service.MemoryService;
|
||||
import com.pandol365.dewey.domain.memory.service.EmbeddingClient;
|
||||
import com.pandol365.dewey.domain.memory.service.MemoryVectorStore;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.data.domain.PageRequest;
|
||||
|
|
@ -29,6 +31,8 @@ public class MemoryServiceImpl implements MemoryService {
|
|||
|
||||
private final MemoryRepository memoryRepository;
|
||||
private final RedisTemplate<String, Object> redisTemplate;
|
||||
private final EmbeddingClient embeddingClient;
|
||||
private final MemoryVectorStore memoryVectorStore;
|
||||
|
||||
private static final String TEMP_MEMORY_KEY_PREFIX = "tempMemory:";
|
||||
private static final String USER_TEMP_MEMORY_KEY_PREFIX = "user:tempMemories:";
|
||||
|
|
@ -59,6 +63,14 @@ public class MemoryServiceImpl implements MemoryService {
|
|||
redisTemplate.opsForZSet().add(userMemoriesKey, memoryKey, System.currentTimeMillis());
|
||||
redisTemplate.expire(userMemoriesKey, ttl);
|
||||
|
||||
// 임베딩 생성 및 PGVector 저장 (실패해도 Redis 저장은 유지)
|
||||
try {
|
||||
float[] embedding = embeddingClient.embed(memoryText);
|
||||
memoryVectorStore.save(userId, memoryText, importance, embedding);
|
||||
} catch (Exception e) {
|
||||
log.warn("임베딩 저장 실패 (계속 진행): {}", e.getMessage());
|
||||
}
|
||||
|
||||
return temporaryMemory;
|
||||
}
|
||||
|
||||
|
|
@ -109,11 +121,17 @@ public class MemoryServiceImpl implements MemoryService {
|
|||
@Override
|
||||
@Transactional(readOnly = true)
|
||||
public List<Memory> searchMemoriesByVector(String query, String userId, Integer limit) {
|
||||
log.info("벡터 기반 메모리 검색: query={}, userId={}, limit={}", query, userId, limit);
|
||||
log.info("벡터 기반 메모리 검색: query={}, userId(optional)={}, limit={}", query, userId, limit);
|
||||
|
||||
// TODO: 벡터 유사도 검색 구현 (Spring AI PgVectorStore 사용)
|
||||
int topK = limit != null && limit > 0 ? limit : 5;
|
||||
try {
|
||||
float[] embedding = embeddingClient.embed(query);
|
||||
return memoryVectorStore.search(embedding, topK);
|
||||
} catch (Exception e) {
|
||||
log.error("벡터 검색 실패: {}", e.getMessage(), e);
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@Transactional(readOnly = true)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
spring.application.name=dewey
|
||||
|
||||
# PostgreSQL
|
||||
spring.datasource.url=jdbc:postgresql://localhost:5432/dewey_memory
|
||||
spring.datasource.url=jdbc:postgresql://192.168.100.230:5432/dewey_memory
|
||||
spring.datasource.username=dewey
|
||||
spring.datasource.password=0bk1rWu98mGl5ea3
|
||||
spring.datasource.driver-class-name=org.postgresql.Driver
|
||||
|
|
@ -15,3 +15,8 @@ spring.jpa.properties.hibernate.format_sql=true
|
|||
# Redis
|
||||
spring.data.redis.host=localhost
|
||||
spring.data.redis.port=6379
|
||||
|
||||
# Embedding service
|
||||
# all-mpnet-base-v2 768d embedding endpoint
|
||||
embedding.api.base-url=http://localhost:8000
|
||||
embedding.api.path=/embed
|
||||
|
|
|
|||
Loading…
Reference in New Issue