|
| 1 | +// Package aipaint ai绘图 |
| 2 | +package aipaint |
| 3 | + |
| 4 | +import ( |
| 5 | + "bytes" |
| 6 | + "encoding/base64" |
| 7 | + "encoding/json" |
| 8 | + "fmt" |
| 9 | + "image" |
| 10 | + "net/url" |
| 11 | + "os" |
| 12 | + "regexp" |
| 13 | + "strconv" |
| 14 | + "strings" |
| 15 | + |
| 16 | + "github.com/FloatTech/floatbox/binary" |
| 17 | + "github.com/FloatTech/floatbox/file" |
| 18 | + "github.com/FloatTech/floatbox/img/writer" |
| 19 | + "github.com/FloatTech/floatbox/web" |
| 20 | + ctrl "github.com/FloatTech/zbpctrl" |
| 21 | + "github.com/FloatTech/zbputils/control" |
| 22 | + "github.com/FloatTech/zbputils/ctxext" |
| 23 | + zero "github.com/wdvxdr1123/ZeroBot" |
| 24 | + "github.com/wdvxdr1123/ZeroBot/message" |
| 25 | +) |
| 26 | + |
| 27 | +var ( |
| 28 | + datapath string |
| 29 | + predictRe = regexp.MustCompile(`{"steps".+?}`) |
| 30 | + // 参考host http://91.217.139.190:5010 http://91.216.169.75:5010 |
| 31 | + aipaintTxt2ImgURL = "/got_image?token=%v&tags=%v" |
| 32 | + aipaintImg2ImgURL = "/got_image2image?token=%v&tags=%v" |
| 33 | + cfg = newServerConfig("data/aipaint/config.json") |
| 34 | +) |
| 35 | + |
| 36 | +type result struct { |
| 37 | + Steps int `json:"steps"` |
| 38 | + Sampler string `json:"sampler"` |
| 39 | + Seed int `json:"seed"` |
| 40 | + Strength float64 `json:"strength"` |
| 41 | + Noise float64 `json:"noise"` |
| 42 | + Scale float64 `json:"scale"` |
| 43 | + Uc string `json:"uc"` |
| 44 | +} |
| 45 | + |
| 46 | +func (r *result) String() string { |
| 47 | + return fmt.Sprintf("steps: %v\nsampler: %v\nseed: %v\nstrength: %v\nnoise: %v\nscale: %v\nuc: %v\n", r.Steps, r.Sampler, r.Seed, r.Strength, r.Noise, r.Scale, r.Uc) |
| 48 | +} |
| 49 | + |
| 50 | +func init() { // 插件主体 |
| 51 | + engine := control.Register("aipaint", &ctrl.Options[*zero.Ctx]{ |
| 52 | + DisableOnDefault: false, |
| 53 | + Help: "ai绘图\n" + |
| 54 | + "- [ ai绘图 | 生成色图 | 生成涩图 | ai画图 ] xxx\n" + |
| 55 | + "- [ 以图绘图 | 以图生图 | 以图画图 ] xxx [图片]|@xxx|[qq号]\n" + |
| 56 | + "- 设置ai绘图配置 [server] [token]\n" + |
| 57 | + "例1: 设置ai绘图配置 http://91.216.169.75:5010 abc\n" + |
| 58 | + "例2: 设置ai绘图配置 http://91.217.139.190:5010 abc\n" + |
| 59 | + "通过 http://91.217.139.190:5010/token 获取token", |
| 60 | + PrivateDataFolder: "aipaint", |
| 61 | + }) |
| 62 | + datapath = file.BOTPATH + "/" + engine.DataFolder() |
| 63 | + engine.OnPrefixGroup([]string{`ai绘图`, `生成色图`, `生成涩图`, `ai画图`}).SetBlock(true). |
| 64 | + Handle(func(ctx *zero.Ctx) { |
| 65 | + server, token, err := cfg.load() |
| 66 | + if err != nil { |
| 67 | + ctx.SendChain(message.Text("ERROR: ", err)) |
| 68 | + return |
| 69 | + } |
| 70 | + ctx.SendChain(message.Text("少女祈祷中...")) |
| 71 | + args := ctx.State["args"].(string) |
| 72 | + data, err := web.GetData(server + fmt.Sprintf(aipaintTxt2ImgURL, token, url.QueryEscape(strings.TrimSpace(strings.ReplaceAll(args, " ", "%20"))))) |
| 73 | + if err != nil { |
| 74 | + ctx.SendChain(message.Text("ERROR: ", err)) |
| 75 | + return |
| 76 | + } |
| 77 | + sendAiImg(ctx, data) |
| 78 | + }) |
| 79 | + engine.OnRegex(`^(以图绘图|以图生图|以图画图)[\s\S]*?(\[CQ:(image\,file=([0-9a-zA-Z]{32}).*|at.+?(\d{5,11}))\].*|(\d+))$`).SetBlock(true). |
| 80 | + Handle(func(ctx *zero.Ctx) { |
| 81 | + server, token, err := cfg.load() |
| 82 | + if err != nil { |
| 83 | + ctx.SendChain(message.Text("ERROR: ", err)) |
| 84 | + return |
| 85 | + } |
| 86 | + c := newContext(ctx.Event.UserID) |
| 87 | + list := ctx.State["regex_matched"].([]string) |
| 88 | + err = c.prepareLogos(list[4]+list[5]+list[6], strconv.FormatInt(ctx.Event.UserID, 10)) |
| 89 | + if err != nil { |
| 90 | + ctx.SendChain(message.Text("ERROR: ", err)) |
| 91 | + return |
| 92 | + } |
| 93 | + args := strings.TrimSuffix(strings.TrimPrefix(list[0], list[1]), list[2]) |
| 94 | + if args == "" { |
| 95 | + ctx.SendChain(message.Text("ERROR: 以图绘图必须添加tag")) |
| 96 | + return |
| 97 | + } |
| 98 | + ctx.SendChain(message.Text("少女祈祷中...")) |
| 99 | + postURL := server + fmt.Sprintf(aipaintImg2ImgURL, token, url.QueryEscape(strings.TrimSpace(strings.ReplaceAll(args, " ", "%20")))) |
| 100 | + |
| 101 | + f, err := os.Open(c.headimgsdir[0]) |
| 102 | + if err != nil { |
| 103 | + ctx.SendChain(message.Text("ERROR: ", err)) |
| 104 | + return |
| 105 | + } |
| 106 | + defer f.Close() |
| 107 | + |
| 108 | + img, _, err := image.Decode(f) |
| 109 | + if err != nil { |
| 110 | + ctx.SendChain(message.Text("ERROR: ", err)) |
| 111 | + return |
| 112 | + } |
| 113 | + imageShape := "" |
| 114 | + switch { |
| 115 | + case img.Bounds().Dx() > img.Bounds().Dy(): |
| 116 | + imageShape = "Landscape" |
| 117 | + case img.Bounds().Dx() == img.Bounds().Dy(): |
| 118 | + imageShape = "Square" |
| 119 | + default: |
| 120 | + imageShape = "Portrait" |
| 121 | + } |
| 122 | + |
| 123 | + // 图片转base64 |
| 124 | + base64Bytes, err := writer.ToBase64(img) |
| 125 | + if err != nil { |
| 126 | + ctx.SendChain(message.Text("ERROR: ", err)) |
| 127 | + return |
| 128 | + } |
| 129 | + data, err := web.PostData(postURL+"&shape="+imageShape, "text/plain", bytes.NewReader(base64Bytes)) |
| 130 | + if err != nil { |
| 131 | + ctx.SendChain(message.Text("ERROR: ", err)) |
| 132 | + return |
| 133 | + } |
| 134 | + sendAiImg(ctx, data) |
| 135 | + }) |
| 136 | + engine.OnRegex(`^设置ai绘图配置\s(.*[^\s$])\s(.+)$`, zero.SuperUserPermission).SetBlock(true). |
| 137 | + Handle(func(ctx *zero.Ctx) { |
| 138 | + regexMatched := ctx.State["regex_matched"].([]string) |
| 139 | + err := cfg.save(regexMatched[1], regexMatched[2]) |
| 140 | + if err != nil { |
| 141 | + ctx.SendChain(message.Text("ERROR: ", err)) |
| 142 | + return |
| 143 | + } |
| 144 | + ctx.SendChain(message.Text("成功设置server为", regexMatched[1], ", token为", regexMatched[2])) |
| 145 | + }) |
| 146 | +} |
| 147 | + |
| 148 | +func sendAiImg(ctx *zero.Ctx, data []byte) { |
| 149 | + var loadData string |
| 150 | + if predictRe.MatchString(binary.BytesToString(data)) { |
| 151 | + loadData = predictRe.FindStringSubmatch(binary.BytesToString(data))[0] |
| 152 | + } |
| 153 | + var r result |
| 154 | + if loadData != "" { |
| 155 | + err := json.Unmarshal(binary.StringToBytes(loadData), &r) |
| 156 | + if err != nil { |
| 157 | + ctx.SendChain(message.Text("ERROR: ", err)) |
| 158 | + return |
| 159 | + } |
| 160 | + } |
| 161 | + encodeStr := base64.StdEncoding.EncodeToString(data) |
| 162 | + m := message.Message{ctxext.FakeSenderForwardNode(ctx, message.Image("base64://"+encodeStr))} |
| 163 | + m = append(m, ctxext.FakeSenderForwardNode(ctx, message.Text(r.String()))) |
| 164 | + if id := ctx.Send(m).ID(); id == 0 { |
| 165 | + ctx.SendChain(message.Text("ERROR: 可能被风控或下载图片用时过长,请耐心等待")) |
| 166 | + } |
| 167 | +} |
0 commit comments