Go实现微信小程序智能AI绘画与艺术创作平台
Go实现微信小程序智能AI绘画与艺术创作平台
·
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strconv"
"sync"
"time"
"github.com/gorilla/mux"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
// 微信小程序AI绘画平台
type WXAIPaintingPlatform struct {
db *gorm.DB
userManager *UserManager
aiPaintingEngine *AIPaintingEngine
styleManager *StyleManager
collectionManager *CollectionManager
orderSystem *OrderSystem
wxAPI WXAPI
storageService *StorageService
notification *NotificationSystem
socialManager *SocialManager
modelTraining *ModelTrainingSystem
contentFilter *ContentFilter
dataAnalytics *DataAnalytics
artworkManager *ArtworkManager
}
// 初始化AI绘画平台
func NewWXAIPaintingPlatform(dsn string) (*WXAIPaintingPlatform, error) {
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("数据库连接失败: %v", err)
}
return &WXAIPaintingPlatform{
db: db,
userManager: NewUserManager(db),
aiPaintingEngine: NewAIPaintingEngine(),
styleManager: NewStyleManager(db),
collectionManager: NewCollectionManager(db),
orderSystem: NewOrderSystem(db),
wxAPI: NewWXAPIImpl(),
storageService: NewStorageService(),
notification: NewNotificationSystem(),
socialManager: NewSocialManager(db),
modelTraining: NewModelTrainingSystem(),
contentFilter: NewContentFilter(),
dataAnalytics: NewDataAnalytics(db),
artworkManager: NewArtworkManager(db),
}, nil
}
// 启动平台
func (p *WXAIPaintingPlatform) Start(ctx context.Context) error {
log.Println("AI绘画平台启动中...")
// 初始化各子系统
if err := p.userManager.Init(ctx); err != nil {
return fmt.Errorf("用户系统初始化失败: %v", err)
}
if err := p.styleManager.LoadStyles(ctx); err != nil {
return fmt.Errorf("风格系统初始化失败: %v", err)
}
if err := p.aiPaintingEngine.InitModels(ctx); err != nil {
return fmt.Errorf("AI模型初始化失败: %v", err)
}
// 启动后台服务
go p.monitorModelTraining(ctx)
go p.cleanExpiredResources(ctx)
go p.generateDailyReports(ctx)
log.Println("AI绘画平台启动完成")
return nil
}
// 用户注册
func (p *WXAIPaintingPlatform) RegisterUser(ctx context.Context, user User) (*User, error) {
// 1. 验证微信用户信息
wxUser, err := p.wxAPI.GetUserInfo(user.WXOpenID)
if err != nil {
return nil, fmt.Errorf("获取微信用户信息失败: %v", err)
}
// 2. 创建用户账号
newUser, err := p.userManager.CreateUser(ctx, user, wxUser)
if err != nil {
return nil, fmt.Errorf("创建用户失败: %v", err)
}
// 3. 发放新手礼包
if err := p.giveNewUserGift(ctx, newUser.ID); err != nil {
log.Printf("发放新手礼包失败: %v", err)
}
// 4. 记录用户行为
p.dataAnalytics.RecordUserAction(ctx, newUser.ID, "register", nil)
return newUser, nil
}
// 发放新手礼包
func (p *WXAIPaintingPlatform) giveNewUserGift(ctx context.Context, userID uint) error {
// 1. 发放免费生成次数
if err := p.userManager.AddFreeGenerations(ctx, userID, 5); err != nil {
return fmt.Errorf("发放免费生成次数失败: %v", err)
}
// 2. 发放VIP试用
if err := p.userManager.AddVIPDays(ctx, userID, 7); err != nil {
return fmt.Errorf("发放VIP试用失败: %v", err)
}
// 3. 发送欢迎消息
if err := p.notification.SendWelcomeMessage(ctx, userID); err != nil {
return fmt.Errorf("发送欢迎消息失败: %v", err)
}
return nil
}
// 用户登录
func (p *WXAIPaintingPlatform) Login(ctx context.Context, code string) (*UserSession, error) {
// 1. 微信登录获取openid
session, err := p.wxAPI.Code2Session(ctx, code)
if err != nil {
return nil, fmt.Errorf("微信登录失败: %v", err)
}
// 2. 获取或创建用户
user, err := p.userManager.GetOrCreateByWXOpenID(ctx, session.OpenID)
if err != nil {
return nil, fmt.Errorf("获取用户信息失败: %v", err)
}
// 3. 创建会话
token, err := p.userManager.CreateSession(ctx, user.ID)
if err != nil {
return nil, fmt.Errorf("创建会话失败: %v", err)
}
// 4. 记录登录
p.dataAnalytics.RecordUserAction(ctx, user.ID, "login", nil)
return &UserSession{
Token: token,
UserID: user.ID,
UserInfo: user,
}, nil
}
// 获取绘画风格列表
func (p *WXAIPaintingPlatform) GetPaintingStyles(ctx context.Context, category string) ([]PaintingStyle, error) {
return p.styleManager.GetStylesByCategory(ctx, category)
}
// 获取热门风格
func (p *WXAIPaintingPlatform) GetPopularStyles(ctx context.Context, limit int) ([]PaintingStyle, error) {
return p.styleManager.GetPopularStyles(ctx, limit)
}
// 获取用户收藏风格
func (p *WXAIPaintingPlatform) GetUserFavoriteStyles(ctx context.Context, userID uint) ([]PaintingStyle, error) {
return p.collectionManager.GetUserFavoriteStyles(ctx, userID)
}
// 收藏风格
func (p *WXAIPaintingPlatform) AddFavoriteStyle(ctx context.Context, userID, styleID uint) error {
return p.collectionManager.AddFavoriteStyle(ctx, userID, styleID)
}
// 取消收藏风格
func (p *WXAIPaintingPlatform) RemoveFavoriteStyle(ctx context.Context, userID, styleID uint) error {
return p.collectionManager.RemoveFavoriteStyle(ctx, userID, styleID)
}
// 生成AI绘画
func (p *WXAIPaintingPlatform) GenerateAIPainting(ctx context.Context, req GenerateRequest) (*GenerateResult, error) {
// 1. 验证用户
user, ok := UserFromContext(ctx)
if !ok {
return nil, errors.New("未授权访问")
}
// 2. 检查用户生成权限
canGenerate, err := p.userManager.CanGenerate(ctx, user.ID)
if err != nil {
return nil, fmt.Errorf("检查生成权限失败: %v", err)
}
if !canGenerate {
return nil, errors.New("生成次数不足,请购买或等待重置")
}
// 3. 获取风格信息
style, err := p.styleManager.GetStyle(ctx, req.StyleID)
if err != nil {
return nil, fmt.Errorf("获取风格信息失败: %v", err)
}
// 4. 内容安全检查
if err := p.contentFilter.CheckText(req.Prompt); err != nil {
return nil, fmt.Errorf("内容安全检查失败: %v", err)
}
// 5. 调用AI生成
startTime := time.Now()
imageData, err := p.aiPaintingEngine.Generate(ctx, req.Prompt, style.ModelName, req.Params)
if err != nil {
return nil, fmt.Errorf("AI生成失败: %v", err)
}
generationTime := time.Since(startTime)
// 6. 保存生成结果
imageURL, err := p.storageService.UploadImage(ctx, imageData)
if err != nil {
return nil, fmt.Errorf("保存图片失败: %v", err)
}
// 7. 记录生成历史
history, err := p.artworkManager.RecordGeneration(ctx, user.ID, req, imageURL, generationTime)
if err != nil {
return nil, fmt.Errorf("记录生成历史失败: %v", err)
}
// 8. 扣减生成次数
if err := p.userManager.UseFreeGeneration(ctx, user.ID); err != nil {
log.Printf("扣减生成次数失败: %v", err)
}
// 9. 记录用户行为
p.dataAnalytics.RecordUserAction(ctx, user.ID, "generate_art", map[string]interface{}{
"style_id": req.StyleID,
"time_cost": generationTime.Seconds(),
})
return &GenerateResult{
ImageURL: imageURL,
HistoryID: history.ID,
TimeCost: generationTime,
Style: *style,
}, nil
}
// 基于图片生成AI绘画
func (p *WXAIPaintingPlatform) GenerateFromImage(ctx context.Context, req GenerateFromImageRequest) (*GenerateResult, error) {
// 1. 验证用户
user, ok := UserFromContext(ctx)
if !ok {
return nil, errors.New("未授权访问")
}
// 2. 检查用户生成权限
canGenerate, err := p.userManager.CanGenerate(ctx, user.ID)
if err != nil {
return nil, fmt.Errorf("检查生成权限失败: %v", err)
}
if !canGenerate {
return nil, errors.New("生成次数不足,请购买或等待重置")
}
// 3. 获取风格信息
style, err := p.styleManager.GetStyle(ctx, req.StyleID)
if err != nil {
return nil, fmt.Errorf("获取风格信息失败: %v", err)
}
// 4. 内容安全检查
if err := p.contentFilter.CheckImage(req.ImageData); err != nil {
return nil, fmt.Errorf("内容安全检查失败: %v", err)
}
// 5. 上传原始图片
sourceImageURL, err := p.storageService.UploadImage(ctx, req.ImageData)
if err != nil {
return nil, fmt.Errorf("上传原始图片失败: %v", err)
}
// 6. 调用AI生成
startTime := time.Now()
imageData, err := p.aiPaintingEngine.GenerateFromImage(ctx, req.ImageData, req.Prompt, style.ModelName, req.Params)
if err != nil {
return nil, fmt.Errorf("AI生成失败: %v", err)
}
generationTime := time.Since(startTime)
// 7. 保存生成结果
imageURL, err := p.storageService.UploadImage(ctx, imageData)
if err != nil {
return nil, fmt.Errorf("保存图片失败: %v", err)
}
// 8. 记录生成历史
history, err := p.artworkManager.RecordGenerationFromImage(ctx, user.ID, req, sourceImageURL, imageURL, generationTime)
if err != nil {
return nil, fmt.Errorf("记录生成历史失败: %v", err)
}
// 9. 扣减生成次数
if err := p.userManager.UseFreeGeneration(ctx, user.ID); err != nil {
log.Printf("扣减生成次数失败: %v", err)
}
// 10. 记录用户行为
p.dataAnalytics.RecordUserAction(ctx, user.ID, "generate_from_image", map[string]interface{}{
"style_id": req.StyleID,
"time_cost": generationTime.Seconds(),
})
return &GenerateResult{
ImageURL: imageURL,
HistoryID: history.ID,
TimeCost: generationTime,
Style: *style,
}, nil
}
// 获取生成历史
func (p *WXAIPaintingPlatform) GetGenerationHistory(ctx context.Context, userID uint, page, pageSize int) ([]GenerationHistory, error) {
return p.artworkManager.GetUserHistory(ctx, userID, page, pageSize)
}
// 获取生成历史详情
func (p *WXAIPaintingPlatform) GetGenerationDetail(ctx context.Context, historyID uint) (*GenerationDetail, error) {
// 1. 验证用户
user, ok := UserFromContext(ctx)
if !ok {
return nil, errors.New("未授权访问")
}
// 2. 获取历史记录
history, err := p.artworkManager.GetHistory(ctx, historyID)
if err != nil {
return nil, fmt.Errorf("获取历史记录失败: %v", err)
}
// 3. 验证所有者
if history.UserID != user.ID {
return nil, errors.New("无权访问此记录")
}
// 4. 获取风格信息
style, err := p.styleManager.GetStyle(ctx, history.StyleID)
if err != nil {
return nil, fmt.Errorf("获取风格信息失败: %v", err)
}
return &GenerationDetail{
History: *history,
Style: *style,
}, nil
}
// 保存作品到画廊
func (p *WXAIPaintingPlatform) SaveToGallery(ctx context.Context, historyID uint, req SaveToGalleryRequest) (*Artwork, error) {
// 1. 验证用户
user, ok := UserFromContext(ctx)
if !ok {
return nil, errors.New("未授权访问")
}
// 2. 获取历史记录
history, err := p.artworkManager.GetHistory(ctx, historyID)
if err != nil {
return nil, fmt.Errorf("获取历史记录失败: %v", err)
}
// 3. 验证所有者
if history.UserID != user.ID {
return nil, errors.New("无权操作此记录")
}
// 4. 保存到画廊
artwork, err := p.artworkManager.SaveToGallery(ctx, history.ID, user.ID, req.Title, req.Description, req.Tags)
if err != nil {
return nil, fmt.Errorf("保存到画廊失败: %v", err)
}
// 5. 记录用户行为
p.dataAnalytics.RecordUserAction(ctx, user.ID, "save_to_gallery", map[string]interface{}{
"artwork_id": artwork.ID,
})
return artwork, nil
}
// 获取用户画廊
func (p *WXAIPaintingPlatform) GetUserGallery(ctx context.Context, userID uint, page, pageSize int) ([]Artwork, error) {
return p.artworkManager.GetUserGallery(ctx, userID, page, pageSize)
}
// 获取热门作品
func (p *WXAIPaintingPlatform) GetPopularArtworks(ctx context.Context, page, pageSize int) ([]Artwork, error) {
return p.artworkManager.GetPopularArtworks(ctx, page, pageSize)
}
// 获取作品详情
func (p *WXAIPaintingPlatform) GetArtworkDetail(ctx context.Context, artworkID uint) (*ArtworkDetail, error) {
// 1. 获取作品基本信息
artwork, err := p.artworkManager.GetArtwork(ctx, artworkID)
if err != nil {
return nil, fmt.Errorf("获取作品失败: %v", err)
}
// 2. 获取作者信息
author, err := p.userManager.GetUser(ctx, artwork.UserID)
if err != nil {
return nil, fmt.Errorf("获取作者信息失败: %v", err)
}
// 3. 获取风格信息
style, err := p.styleManager.GetStyle(ctx, artwork.StyleID)
if err != nil {
return nil, fmt.Errorf("获取风格信息失败: %v", err)
}
// 4. 获取统计信息
stats, err := p.artworkManager.GetArtworkStats(ctx, artworkID)
if err != nil {
return nil, fmt.Errorf("获取作品统计失败: %v", err)
}
// 5. 检查当前用户是否收藏
var isCollected bool
if user, ok := UserFromContext(ctx); ok {
isCollected = p.collectionManager.IsArtworkCollected(ctx, user.ID, artworkID)
}
return &ArtworkDetail{
Artwork: *artwork,
Author: *author,
Style: *style,
Stats: *stats,
IsCollected: isCollected,
}, nil
}
// 收藏作品
func (p *WXAIPaintingPlatform) CollectArtwork(ctx context.Context, userID, artworkID uint) error {
return p.collectionManager.CollectArtwork(ctx, userID, artworkID)
}
// 取消收藏作品
func (p *WXAIPaintingPlatform) UncollectArtwork(ctx context.Context, userID, artworkID uint) error {
return p.collectionManager.UncollectArtwork(ctx, userID, artworkID)
}
// 点赞作品
func (p *WXAIPaintingPlatform) LikeArtwork(ctx context.Context, userID, artworkID uint) error {
// 1. 点赞作品
if err := p.artworkManager.LikeArtwork(ctx, userID, artworkID); err != nil {
return fmt.Errorf("点赞作品失败: %v", err)
}
// 2. 获取作品作者
artwork, err := p.artworkManager.GetArtwork(ctx, artworkID)
if err != nil {
return fmt.Errorf("获取作品信息失败: %v", err)
}
// 3. 如果不是自己的作品,发送通知
if artwork.UserID != userID {
if err := p.notification.SendLikeNotification(ctx, userID, artwork.UserID, artworkID); err != nil {
log.Printf("发送点赞通知失败: %v", err)
}
}
// 4. 记录用户行为
p.dataAnalytics.RecordUserAction(ctx, userID, "like_artwork", map[string]interface{}{
"artwork_id": artworkID,
})
return nil
}
// 取消点赞
func (p *WXAIPaintingPlatform) UnlikeArtwork(ctx context.Context, userID, artworkID uint) error {
return p.artworkManager.UnlikeArtwork(ctx, userID, artworkID)
}
// 评论作品
func (p *WXAIPaintingPlatform) CommentArtwork(ctx context.Context, userID, artworkID uint, content string) (*Comment, error) {
// 1. 内容安全检查
if err := p.contentFilter.CheckText(content); err != nil {
return nil, fmt.Errorf("内容安全检查失败: %v", err)
}
// 2. 添加评论
comment, err := p.artworkManager.AddComment(ctx, userID, artworkID, content)
if err != nil {
return nil, fmt.Errorf("添加评论失败: %v", err)
}
// 3. 获取作品作者
artwork, err := p.artworkManager.GetArtwork(ctx, artworkID)
if err != nil {
return nil, fmt.Errorf("获取作品信息失败: %v", err)
}
// 4. 如果不是自己的作品,发送通知
if artwork.UserID != userID {
if err := p.notification.SendCommentNotification(ctx, userID, artwork.UserID, artworkID, comment.ID); err != nil {
log.Printf("发送评论通知失败: %v", err)
}
}
// 5. 记录用户行为
p.dataAnalytics.RecordUserAction(ctx, userID, "comment_artwork", map[string]interface{}{
"artwork_id": artworkID,
})
return comment, nil
}
// 获取作品评论
func (p *WXAIPaintingPlatform) GetArtworkComments(ctx context.Context, artworkID uint, page, pageSize int) ([]Comment, error) {
return p.artworkManager.GetArtworkComments(ctx, artworkID, page, pageSize)
}
// 分享作品
func (p *WXAIPaintingPlatform) ShareArtwork(ctx context.Context, userID, artworkID uint, platform string) (*ShareResult, error) {
// 1. 获取作品信息
artwork, err := p.artworkManager.GetArtwork(ctx, artworkID)
if err != nil {
return nil, fmt.Errorf("获取作品信息失败: %v", err)
}
// 2. 生成分享链接
shareLink, err := p.generateShareLink(ctx, userID, artworkID, platform)
if err != nil {
return nil, fmt.Errorf("生成分享链接失败: %v", err)
}
// 3. 记录分享
if err := p.dataAnalytics.RecordShare(ctx, userID, artworkID, platform); err != nil {
return nil, fmt.Errorf("记录分享失败: %v", err)
}
// 4. 检查是否有分享奖励
if reward, err := p.userManager.CheckShareReward(ctx, userID); err == nil && reward != nil {
// 有分享奖励
return &ShareResult{
Link: shareLink,
Image: artwork.ImageURL,
Title: artwork.Title,
HasReward: true,
Reward: *reward,
}, nil
}
return &ShareResult{
Link: shareLink,
Image: artwork.ImageURL,
Title: artwork.Title,
HasReward: false,
}, nil
}
// 生成分享链接
func (p *WXAIPaintingPlatform) generateShareLink(ctx context.Context, userID, artworkID uint, platform string) (string, error) {
// 在实际应用中,这里应该生成带有追踪参数的分享链接
return fmt.Sprintf("https://art.example.com/share/%d/%d?platform=%s", userID, artworkID, platform), nil
}
// 训练自定义风格模型
func (p *WXAIPaintingPlatform) TrainCustomStyle(ctx context.Context, req TrainCustomStyleRequest) (*TrainingTask, error) {
// 1. 验证用户
user, ok := UserFromContext(ctx)
if !ok {
return nil, errors.New("未授权访问")
}
// 2. 检查用户权限
canTrain, err := p.userManager.CanTrainModel(ctx, user.ID)
if err != nil {
return nil, fmt.Errorf("检查训练权限失败: %v", err)
}
if !canTrain {
return nil, errors.New("无权训练自定义模型")
}
// 3. 上传训练图片
var imageURLs []string
for _, img := range req.TrainingImages {
url, err := p.storageService.UploadImage(ctx, img)
if err != nil {
return nil, fmt.Errorf("上传训练图片失败: %v", err)
}
imageURLs = append(imageURLs, url)
}
// 4. 创建训练任务
task, err := p.modelTraining.CreateTrainingTask(ctx, user.ID, req.StyleName, req.StyleDescription, imageURLs, req.BaseModel)
if err != nil {
return nil, fmt.Errorf("创建训练任务失败: %v", err)
}
// 5. 扣减训练次数
if err := p.userManager.UseModelTraining(ctx, user.ID); err != nil {
log.Printf("扣减训练次数失败: %v", err)
}
// 6. 记录用户行为
p.dataAnalytics.RecordUserAction(ctx, user.ID, "train_model", map[string]interface{}{
"task_id": task.ID,
})
return task, nil
}
// 获取训练任务列表
func (p *WXAIPaintingPlatform) GetTrainingTasks(ctx context.Context, userID uint, status string) ([]TrainingTask, error) {
return p.modelTraining.GetUserTasks(ctx, userID, status)
}
// 获取训练任务详情
func (p *WXAIPaintingPlatform) GetTrainingTaskDetail(ctx context.Context, taskID uint) (*TrainingTaskDetail, error) {
// 1. 验证用户
user, ok := UserFromContext(ctx)
if !ok {
return nil, errors.New("未授权访问")
}
// 2. 获取任务信息
task, err := p.modelTraining.GetTask(ctx, taskID)
if err != nil {
return nil, fmt.Errorf("获取训练任务失败: %v", err)
}
// 3. 验证所有者
if task.UserID != user.ID {
return nil, errors.New("无权访问此任务")
}
// 4. 获取训练进度
progress, err := p.modelTraining.GetTrainingProgress(ctx, taskID)
if err != nil {
return nil, fmt.Errorf("获取训练进度失败: %v", err)
}
// 5. 获取示例输出
var sampleOutputs []string
if task.Status == "completed" {
sampleOutputs, err = p.modelTraining.GetSampleOutputs(ctx, taskID)
if err != nil {
return nil, fmt.Errorf("获取示例输出失败: %v", err)
}
}
return &TrainingTaskDetail{
Task: *task,
Progress: *progress,
SampleOutputs: sampleOutputs,
}, nil
}
// 使用自定义风格生成
func (p *WXAIPaintingPlatform) GenerateWithCustomStyle(ctx context.Context, req GenerateWithCustomStyleRequest) (*GenerateResult, error) {
// 1. 验证用户
user, ok := UserFromContext(ctx)
if !ok {
return nil, errors.New("未授权访问")
}
// 2. 检查用户生成权限
canGenerate, err := p.userManager.CanGenerate(ctx, user.ID)
if err != nil {
return nil, fmt.Errorf("检查生成权限失败: %v", err)
}
if !canGenerate {
return nil, errors.New("生成次数不足,请购买或等待重置")
}
// 3. 获取自定义风格
customStyle, err := p.modelTraining.GetCustomStyle(ctx, req.StyleID)
if err != nil {
return nil, fmt.Errorf("获取自定义风格失败: %v", err)
}
// 4. 验证所有者
if customStyle.UserID != user.ID {
return nil, errors.New("无权使用此自定义风格")
}
// 5. 内容安全检查
if err := p.contentFilter.CheckText(req.Prompt); err != nil {
return nil, fmt.Errorf("内容安全检查失败: %v", err)
}
// 6. 调用AI生成
startTime := time.Now()
imageData, err := p.aiPaintingEngine.GenerateWithCustomModel(ctx, req.Prompt, customStyle.ModelName, req.Params)
if err != nil {
return nil, fmt.Errorf("AI生成失败: %v", err)
}
generationTime := time.Since(startTime)
// 7. 保存生成结果
imageURL, err := p.storageService.UploadImage(ctx, imageData)
if err != nil {
return nil, fmt.Errorf("保存图片失败: %v", err)
}
// 8. 记录生成历史
history, err := p.artworkManager.RecordCustomGeneration(ctx, user.ID, req, imageURL, generationTime)
if err != nil {
return nil, fmt.Errorf("记录生成历史失败: %v", err)
}
// 9. 扣减生成次数
if err := p.userManager.UseFreeGeneration(ctx, user.ID); err != nil {
log.Printf("扣减生成次数失败: %v", err)
}
// 10. 记录用户行为
p.dataAnalytics.RecordUserAction(ctx, user.ID, "generate_with_custom", map[string]interface{}{
"style_id": req.StyleID,
"time_cost": generationTime.Seconds(),
})
return &GenerateResult{
ImageURL: imageURL,
HistoryID: history.ID,
TimeCost: generationTime,
CustomStyle: customStyle,
}, nil
}
// 购买生成次数
func (p *WXAIPaintingPlatform) PurchaseGenerations(ctx context.Context, userID uint, packID uint) (*Order, error) {
// 1. 获取生成次数包
generationPack, err := p.orderSystem.GetGenerationPack(ctx, packID)
if err != nil {
return nil, fmt.Errorf("获取生成次数包失败: %v", err)
}
// 2. 创建订单
order, err := p.orderSystem.CreateGenerationOrder(ctx, userID, generationPack)
if err != nil {
return nil, fmt.Errorf("创建订单失败: %v", err)
}
// 3. 发起支付
paymentParams, err := p.orderSystem.CreatePayment(ctx, order.ID)
if err != nil {
return nil, fmt.Errorf("创建支付失败: %v", err)
}
// 4. 记录用户行为
p.dataAnalytics.RecordUserAction(ctx, userID, "purchase_generations", map[string]interface{}{
"pack_id": packID,
"order_id": order.ID,
})
return &Order{
ID: order.ID,
OrderNo: order.OrderNo,
Amount: order.Amount,
PaymentInfo: paymentParams,
}, nil
}
// 处理支付回调
func (p *WXAIPaintingPlatform) HandlePaymentCallback(ctx context.Context, data PaymentCallbackData) error {
// 1. 验证支付结果
if err := p.orderSystem.VerifyPayment(data); err != nil {
return fmt.Errorf("支付验证失败: %v", err)
}
// 2. 获取订单
order, err := p.orderSystem.GetOrderByNo(ctx, data.OrderNo)
if err != nil {
return fmt.Errorf("获取订单失败: %v", err)
}
// 3. 更新订单状态
if err := p.orderSystem.CompleteOrder(ctx, order.ID); err != nil {
return fmt.Errorf("更新订单状态失败: %v", err)
}
// 4. 发放购买内容
switch order.Type {
case "generation_pack":
if err := p.userManager.AddGenerations(ctx, order.UserID, order.ItemID, order.Quantity); err != nil {
return fmt.Errorf("发放生成次数失败: %v", err)
}
case "vip":
if err := p.userManager.AddVIPDays(ctx, order.UserID, order.Quantity); err != nil {
return fmt.Errorf("发放VIP天数失败: %v", err)
}
case "model_training":
if err := p.userManager.AddModelTrainings(ctx, order.UserID, order.Quantity); err != nil {
return fmt.Errorf("发放训练次数失败: %v", err)
}
}
// 5. 发送通知
if err := p.notification.SendPaymentSuccess(ctx, order.UserID, order.ID); err != nil {
log.Printf("发送支付成功通知失败: %v", err)
}
// 6. 记录支付成功
p.dataAnalytics.RecordPayment(ctx, order.UserID, order.ID, order.Amount)
return nil
}
// 获取用户信息
func (p *WXAIPaintingPlatform) GetUserInfo(ctx context.Context, userID uint) (*UserInfo, error) {
// 1. 获取用户基本信息
user, err := p.userManager.GetUser(ctx, userID)
if err != nil {
return nil, fmt.Errorf("获取用户信息失败: %v", err)
}
// 2. 获取用户统计信息
stats, err := p.userManager.GetUserStats(ctx, userID)
if err != nil {
return nil, fmt.Errorf("获取用户统计信息失败: %v", err)
}
// 3. 获取VIP信息
vipInfo, err := p.userManager.GetVIPInfo(ctx, userID)
if err != nil {
return nil, fmt.Errorf("获取VIP信息失败: %v", err)
}
// 4. 获取生成次数
generations, err := p.userManager.GetGenerationsLeft(ctx, userID)
if err != nil {
return nil, fmt.Errorf("获取生成次数失败: %v", err)
}
// 5. 获取训练次数
trainings, err := p.userManager.GetTrainingsLeft(ctx, userID)
if err != nil {
return nil, fmt.Errorf("获取训练次数失败: %v", err)
}
return &UserInfo{
User: *user,
Stats: *stats,
VIPInfo: *vipInfo,
Generations: generations,
Trainings: trainings,
}, nil
}
// 获取首页数据
func (p *WXAIPaintingPlatform) GetHomePageData(ctx context.Context) (*HomePageData, error) {
var data HomePageData
var err error
// 1. 获取轮播图
data.Banners, err = p.styleManager.GetBanners(ctx)
if err != nil {
return nil, fmt.Errorf("获取轮播图失败: %v", err)
}
// 2. 获取推荐风格
data.RecommendedStyles, err = p.styleManager.GetRecommendedStyles(ctx, 8)
if err != nil {
return nil, fmt.Errorf("获取推荐风格失败: %v", err)
}
// 3. 获取热门作品
data.PopularArtworks, err = p.artworkManager.GetPopularArtworks(ctx, 0, 6)
if err != nil {
return nil, fmt.Errorf("获取热门作品失败: %v", err)
}
// 4. 获取最新作品
data.LatestArtworks, err = p.artworkManager.GetLatestArtworks(ctx, 0, 6)
if err != nil {
return nil, fmt.Errorf("获取最新作品失败: %v", err)
}
// 5. 获取热门标签
data.HotTags, err = p.artworkManager.GetHotTags(ctx, 10)
if err != nil {
return nil, fmt.Errorf("获取热门标签失败: %v", err)
}
// 6. 获取用户数据(如果已登录)
if user, ok := UserFromContext(ctx); ok {
data.UserStats, err = p.userManager.GetUserStats(ctx, user.ID)
if err != nil {
log.Printf("获取用户统计信息失败: %v", err)
}
}
return &data, nil
}
// 监控模型训练
func (p *WXAIPaintingPlatform) monitorModelTraining(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
// 检查训练中的任务
tasks, err := p.modelTraining.GetRunningTasks(ctx)
if err != nil {
log.Printf("获取训练中任务失败: %v", err)
continue
}
for _, task := range tasks {
// 更新训练进度
progress, err := p.modelTraining.CheckTrainingProgress(ctx, task.ID)
if err != nil {
log.Printf("检查训练进度失败: %v", err)
continue
}
// 如果训练完成
if progress.IsCompleted {
// 保存训练模型
if err := p.modelTraining.CompleteTraining(ctx, task.ID, progress.ModelPath); err != nil {
log.Printf("完成训练任务失败: %v", err)
continue
}
// 生成示例图片
if err := p.modelTraining.GenerateSampleOutputs(ctx, task.ID); err != nil {
log.Printf("生成示例输出失败: %v", err)
}
// 发送通知
if err := p.notification.SendTrainingComplete(ctx, task.UserID, task.ID); err != nil {
log.Printf("发送训练完成通知失败: %v", err)
}
}
}
}
}
}
// 清理过期资源
func (p *WXAIPaintingPlatform) cleanExpiredResources(ctx context.Context) {
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
// 清理过期临时图片
if err := p.storageService.CleanExpiredImages(ctx); err != nil {
log.Printf("清理过期图片失败: %v", err)
}
// 清理未完成的训练任务
if err := p.modelTraining.CleanFailedTasks(ctx); err != nil {
log.Printf("清理失败任务失败: %v", err)
}
}
}
}
// 生成每日报表
func (p *WXAIPaintingPlatform) generateDailyReports(ctx context.Context) {
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
// 生成平台统计报告
report, err := p.dataAnalytics.GenerateDailyReport(ctx)
if err != nil {
log.Printf("生成日报失败: %v", err)
continue
}
// 发送给管理员
admins, err := p.userManager.GetAdminUsers(ctx)
if err != nil {
log.Printf("获取管理员列表失败: %v", err)
continue
}
for _, admin := range admins {
if err := p.notification.SendDailyReport(ctx, admin.ID, report); err != nil {
log.Printf("发送日报失败: %v", err)
}
}
}
}
}
// 启动HTTP服务
func (p *WXAIPaintingPlatform) StartHTTPServer(addr string) error {
r := mux.NewRouter()
// API路由
api := r.PathPrefix("/api/v1").Subrouter()
api.Use(p.authMiddleware)
// 用户相关
api.HandleFunc("/register", p.handleRegister).Methods("POST")
api.HandleFunc("/login", p.handleLogin).Methods("POST")
api.HandleFunc("/user/info", p.handleGetUserInfo).Methods("GET")
// AI生成相关
api.HandleFunc("/generate", p.handleGenerate).Methods("POST")
api.HandleFunc("/generate/from-image", p.handleGenerateFromImage).Methods("POST")
api.HandleFunc("/generate/history", p.handleGetGenerationHistory).Methods("GET")
api.HandleFunc("/generate/history/{id}", p.handleGetGenerationDetail).Methods("GET")
// 画廊相关
api.HandleFunc("/gallery/save", p.handleSaveToGallery).Methods("POST")
api.HandleFunc("/gallery/user", p.handleGetUserGallery).Methods("GET")
api.HandleFunc("/gallery/popular", p.handleGetPopularArtworks).Methods("GET")
api.HandleFunc("/gallery/artwork/{id}", p.handleGetArtworkDetail).Methods("GET")
api.HandleFunc("/gallery/artwork/{id}/like", p.handleLikeArtwork).Methods("POST")
api.HandleFunc("/gallery/artwork/{id}/unlike", p.handleUnlikeArtwork).Methods("POST")
api.HandleFunc("/gallery/artwork/{id}/comment", p.handleCommentArtwork).Methods("POST")
api.HandleFunc("/gallery/artwork/{id}/comments", p.handleGetArtworkComments).Methods("GET")
// 风格相关
api.HandleFunc("/styles", p.handleGetStyles).Methods("GET")
api.HandleFunc("/styles/popular", p.handleGetPopularStyles).Methods("GET")
api.HandleFunc("/styles/favorites", p.handleGetFavoriteStyles).Methods("GET")
api.HandleFunc("/styles/favorites/add", p.handleAddFavoriteStyle).Methods("POST")
api.HandleFunc("/styles/favorites/remove", p.handleRemoveFavoriteStyle).Methods("POST")
// 训练相关
api.HandleFunc("/train", p.handleTrainCustomStyle).Methods("POST")
api.HandleFunc("/train/tasks", p.handleGetTrainingTasks).Methods("GET")
api.HandleFunc("/train/tasks/{id}", p.handleGetTrainingTaskDetail).Methods("GET")
api.HandleFunc("/generate/custom", p.handleGenerateWithCustomStyle).Methods("POST")
// 支付相关
api.HandleFunc("/purchase/generations", p.handlePurchaseGenerations).Methods("POST")
api.HandleFunc("/payment/callback", p.handlePaymentCallback).Methods("POST")
// 首页数据
api.HandleFunc("/home", p.handleGetHomePageData).Methods("GET")
log.Printf("HTTP服务启动在 %s", addr)
return http.ListenAndServe(addr, r)
}
// 认证中间件
func (p *WXAIPaintingPlatform) authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
if token == "" {
http.Error(w, "未授权访问", http.StatusUnauthorized)
return
}
user, err := p.userManager.VerifyToken(token)
if err != nil {
http.Error(w, "无效的令牌", http.StatusUnauthorized)
return
}
ctx := context.WithValue(r.Context(), "user", user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// 从上下文中获取用户
func UserFromContext(ctx context.Context) (*User, bool) {
user, ok := ctx.Value("user").(*User)
return user, ok
}
// 主函数
func main() {
// 初始化平台
platform, err := NewWXAIPaintingPlatform("user:pass@tcp(127.0.0.1:3306)/ai_painting?charset=utf8mb4&parseTime=True&loc=Local")
if err != nil {
log.Fatal("初始化平台失败:", err)
}
// 启动平台
ctx := context.Background()
if err := platform.Start(ctx); err != nil {
log.Fatal("平台启动失败:", err)
}
// 启动HTTP服务
if err := platform.StartHTTPServer(":8080"); err != nil {
log.Fatal("HTTP服务启动失败:", err)
}
}
使用说明
功能特点
-
多风格AI绘画生成:
- 支持多种艺术风格转换
- 文字描述生成图像
- 图片风格迁移
-
自定义模型训练:
- 用户上传图片训练专属风格
- 训练进度实时监控
- 模型效果预览
-
社交艺术社区:
- 作品分享与展示
- 点赞评论互动
- 热门作品推荐
-
商业化功能:
- 生成次数购买
- VIP会员服务
- 自定义模型训练包
-
内容安全:
- 图片内容审核
- 文字内容过滤
- 用户举报机制
核心组件
-
AI生成引擎:
- 多模型集成
- 分布式推理
- 生成队列管理
-
模型训练系统:
- 训练任务调度
- GPU资源管理
- 模型版本控制
-
艺术风格库:
- 风格分类管理
- 风格效果预览
- 用户收藏系统
-
社交互动系统:
- 作品展示
- 用户互动
- 内容推荐
-
商业化系统:
- 订单管理
- 支付集成
- VIP特权
使用方法
-
初始化平台:
platform, err := NewWXAIPaintingPlatform("数据库连接字符串") if err != nil { log.Fatal("初始化失败:", err) }
-
用户注册:
user, err := platform.RegisterUser(ctx, User{ WXOpenID: "微信openid", Nickname: "用户昵称", })
-
生成AI绘画:
result, err := platform.GenerateAIPainting(ctx, GenerateRequest{ Prompt: "星空下的城堡", StyleID: 123, })
-
训练自定义风格:
task, err := platform.TrainCustomStyle(ctx, TrainCustomStyleRequest{ StyleName: "我的专属风格", TrainingImages: [][]byte{...}, })
-
分享作品:
shareResult, err := platform.ShareArtwork(ctx, userID, artworkID, "wx")
应用场景
-
个人艺术创作:
- 将想法转化为艺术作品
- 照片艺术化处理
- 创作独特风格的图像
-
社交分享:
- 展示AI生成的艺术作品
- 参与艺术社区互动
- 发现灵感和创意
-
商业设计:
- 快速生成设计概念图
- 品牌视觉风格探索
- 营销素材创作
-
教育娱乐:
- 艺术风格学习
- AI绘画体验
- 创意激发工具
技术亮点
-
多模型集成:
- Stable Diffusion
- GAN模型
- 神经风格迁移
-
高效推理:
- 模型量化
- 批量推理
- 缓存机制
-
微信深度集成:
- 小程序登录
- 微信支付
- 社交分享
-
分布式训练:
- 多GPU并行
- 训练任务调度
- 资源监控
扩展建议
-
AR展示功能:
- AR画廊
- 作品实景展示
- 3D艺术空间
-
NFT集成:
- 作品上链
- 数字收藏品
- 区块链交易
-
协作创作:
- 多人联合创作
- 风格融合
- 实时协作
-
AI辅助工具:
- 自动调色
- 构图建议
- 创意提示生成
这个系统为微信小程序提供了一个完整的AI艺术创作平台,从简单的风格转换到复杂的自定义模型训练,满足了从普通用户到专业创作者的不同需求。
更多推荐
所有评论(0)