WDL(Wide Deep Learning for Recommender Systems)——Google经典CTR预估模型

一、文章简介

        Wide & Deep Learning for Recommender Systems这篇文章介绍了一种结合宽线性模型和深度神经网络的方法,以实现推荐系统中的记忆和泛化。这种方法在Google Play商店的应用推荐系统中进行了评估,展示了其显著的性能提升。

推荐系统中的记忆和泛化

        为了实现记忆和泛化,Wide & Deep模型结合了宽线性模型和深度神经网络:

1.宽组件(Wide Component)

        宽组件的主要功能是实现记忆,即捕捉特征之间的频繁共现关系。这部分模型采用线性模型,利用交叉乘积特征来捕捉特征之间的高阶关系。

1). 原始输入特征和交叉乘积特征
  • 原始输入特征:这些是从用户和上下文数据中提取的直接特征。例如,用户的安装应用、语言、年龄等。
  • 交叉乘积特征:通过交叉乘积转换生成的新特征,这些特征通过组合原始特征来捕捉特征间的交互。例如,“AND(gender=female, language=en)”表示女性用户使用英语。
2). 公式

        宽组件的线性组合公式:

$\operatorname{Wide}(\mathbf{x})=\mathbf{w}_{\text {wide }}^T[\mathbf{x}, \phi(\mathbf{x})]$

其中:

  • \mathbf{x} 是原始输入特征向量。
  •  $\phi(\mathbf{x})$是交叉乘积特征向量。
  • $\mathbf{W}_{\text {wide }}$ 是宽组件的权重向量。
3). 记忆功能

        宽组件通过权重向量$\mathbf{W}_{\text {wide }}$​ 学习特征间的共现关系。例如,如果某用户安装了Netflix且展示了Pandora,则特征“AND(user_installed_app=netflix, impression_app=pandora)”的值为1,模型可以利用这个信息来进行记忆。

2.深组件(Deep Component)

        深组件的主要功能是实现泛化,即学习特征之间的潜在关系,处理未见过的新特征组合。深组件通过深度神经网络来实现,能够更好地捕捉复杂的非线性关系。

1).嵌入层

        类别特征嵌入:将高维稀疏的类别特征转化为低维稠密的嵌入向量。每个类别特征(如“language=en”)被映射到一个32维的嵌入向量。公式:

\mathbf{e} = Embedding (x)

其中,\mathbf{e} 是嵌入向量,\mathbf{x}是类别特征。

2).隐藏层
  • 连接嵌入和稠密特征:将所有嵌入向量和稠密特征连接在一起,形成一个约1200维的稠密向量。
  • 多层感知器:通过多层感知器(MLP)进行处理,通常包括3个ReLU层,每层执行非线性变换,捕捉复杂的特征关系

$\mathbf{a}^{(l)}=f\left(\mathbf{W}^{(l)} \mathbf{a}^{(l-1)}+\mathbf{b}^{(l)}\right)$

其中:

  • \mathbf{a}^{(l)}是第l层的激活值。
  • \mathbf{W}^{(l)}是第l层的权重矩阵。
  • \mathbf{b}^{(l)}是第l层的偏置向量。
  • f是激活函数,通常为ReLU
3).泛化功能

        深组件通过嵌入层和多层感知器学习特征之间的非线性关系,能够处理以前未见过的新特征组合。例如,通过学习用户的行为模式和上下文信息,模型可以生成新的推荐。

3).实例代码
import tensorflow as tf# 创建一个简单的模型,包括一个嵌入层、一个隐藏层和一个输出层
model = tf.keras.Sequential([tf.keras.layers.Embedding(input_dim=4, output_dim=32, input_length=1),tf.keras.layers.Flatten(),tf.keras.layers.Dense(64, activation='relu'),  # 隐藏层tf.keras.layers.Dense(1)  # 输出层
])# 编译模型
model.compile(optimizer='adam', loss='mse')# 打印嵌入层的权重(训练前)
print("嵌入层权重(训练前):")
print(model.layers[0].get_weights()[0])# 创建简单的数据
import numpy as np
x_train = np.array([[0], [1], [2], [3]])
y_train = np.array([1.0, 2.0, 3.0, 4.0])# 训练模型
model.fit(x_train, y_train, epochs=100, verbose=0)# 打印嵌入层的权重(训练后)
print("嵌入层权重(训练后):")
print(model.layers[0].get_weights()[0])

3.结合记忆和泛化

        宽组件和深组件的输出通过加权和进行组合,作为最终的预测结果。在训练过程中,这两部分是同时优化的,使得模型能够平衡记忆和泛化的需求。具体过程如下:

1).计算宽组件的输出:

        宽组件的输出是原始输入特征和交叉乘积特征的线性组合:

$\operatorname{Wide}(\mathbf{x})=\mathbf{w}_{\text {wide }}^T[\mathbf{x}, \phi(\mathbf{x})]$

2).计算深组件的输出

深组件的输出是嵌入层和多层感知器处理后的结果:

$\operatorname{Deep}(\mathbf{x})=\mathbf{w}_{\text {deep }}^T \mathbf{a}^{\left(l_f\right)}$

其中, $\mathbf{a}^{\left(l_f\right)}$是深度模型最后一层的激活值。

3).组合输出

        宽组件和深组件的输出通过加权和进行组合,作为最终的预测值:

$P(Y=1 \mid \mathbf{x})=\sigma\left(\mathbf{w}_{\text {wide }}^T[\mathbf{x}, \phi(\mathbf{x})]+\mathbf{w}_{\text {deep }}^T \mathbf{a}^{\left(l_f\right)}+b\right)$

其中,$\sigma$ 是sigmoid激活函数,\mathbf{w}_{\text{wide}}\mathbf{w}_{\text{deep}}分别是宽组件和深组件的权重向量,$\mathbf{a}^{\left(l_f\right)}$是深组件最后一层的激活值,b是偏置项。

4).损失函数和优化

        使用逻辑损失函数(logistic loss function)进行联合训练,通过反向传播算法同时优化宽组件和深组件的参数:

L=-\frac{1}{N} \sum_{i=1}^N\left[y_i \log \left(\hat{y}_i\right)+\left(1-y_i\right) \log \left(1-\hat{y}_i\right)\right]

其中:

  • N 是样本的数量。
  • y_i​ 是第i个样本的实际标签(0 或 1)。
  • \hat{y}_i是第i个样本的预测概率,即样本属于类别 1 的概率。
  • \log是自然对数。
损失函数的意义
  • 当实际标签y_i 为 1 时,损失函数的第一项 y_i \log \left(\hat{y}_i\right)起作用,第二项为零。这部分损失鼓励模型将\hat{y}_i 尽可能地接近 1。
  • 当实际标签 y_i为 0 时,损失函数的第二项\left(1-y_i \right )\log\left(1-\hat{y}_i \right ) 起作用,第一项为零。这部分损失鼓励模型将 \hat{y}_i尽可能地接近 0。

通过最小化这个损失函数,模型会在预测时更加准确地反映实际标签。

逻辑损失函数的特性
  • 凸性:逻辑损失函数是一个凸函数,这意味着存在全局最优解(证明见下一篇博客)。
  • 概率解释:逻辑损失函数直接反映了模型预测概率的准确性,能够有效处理不平衡数据集。

4.结论与意义

  • Wide & Deep模型成功结合了记忆和泛化的优势,在推荐系统中表现出色。
  • 实际应用中,通过在线实验验证了其有效性和改进。
  • 提供了开源实现,为进一步研究和应用提供了基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3249648.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

食堂采购系统开发:从需求分析到上线实施的完整指南

本篇文章,笔者将详细介绍食堂采购系统从需求分析到上线实施的完整过程,旨在为开发团队和管理者提供一个系统化的指南。 一、需求分析 1.用户需求 常见的需求包括: -采购计划管理 -供应商管理 -库存管理 -成本控制 -报表生成 2.系统功…

RK3568笔记四十:设备树

若该文为原创文章,转载请注明原文出处。 一、介绍 设备树 (Device Tree) 的作用就是描述一个硬件平台的硬件资源,一般描述那些不能动态探测到的设备,可以被动态探测到的设备是不需要描述。 设备树可以被 bootloader(uboot) 传递到内核&#x…

CentOS6minimal安装nginx-1.26.1.tar.gz 笔记240718

CentOS6安装新版nginx 240718, CentOS6.1-minimal 安装 nginx-1.26.1.tar.gz 下载 nginx-1.26.1.tar.gz 的页面 : https://nginx.org/en/download.html 下载 nginx-1.26.1.tar.gz : https://nginx.org/download/nginx-1.26.1.tar.gz CentOS6.1已过期, 给它更换yum源, 将下面…

SpringCloud------Sentinel(微服务保护)

目录 雪崩问题 处理方式!!!技术选型 Sentinel 启动命令使用步骤引入依赖配置控制台地址 访问微服务触发监控 限流规则------故障预防流控模式流控效果 FeignClient整合Sentinel线程隔离-------故障处理线程池隔离和信号量隔离​编辑 两种方式优缺点设置方式 熔断降级-----…

简述乐观锁和悲观锁——Java

悲观锁和乐观锁 悲观就是任何事都认为会往坏处发生,乐观就是认为任何事都会往好处发生。 打个比方,假如一个公司里只有一台打印机,如果多个人同时打印文件,可能出现混乱的问题,他的资料打印在了我的资料上&#xff0…

【视频讲解】神经网络、Lasso回归、线性回归、随机森林、ARIMA股票价格时间序列预测|附代码数据

全文链接:https://tecdat.cn/?p37019 分析师:Haopeng Li 随着我国股票市场规模的不断扩大、制度的不断完善,它在金融市场中也成为了越来越不可或缺的一部分。 【视频讲解】神经网络、Lasso回归、线性回归、随机森林、ARIMA股票价格时间序列…

Unity动画系统(4)

6.3 动画系统高级1-1_哔哩哔哩_bilibili p333- 声音组件添加 using System.Collections; using System.Collections.Generic; using UnityEngine; public class RobotAnimationController : MonoBehaviour { [Header("平滑过渡时间")] [Range(0,3)] publ…

AI智能名片S2B2C商城小程序在社群去中心化管理中的应用与价值深度探索

摘要:随着互联网技术的飞速发展,社群经济作为一种新兴的商业模式,正逐渐成为企业与用户之间建立深度连接、促进商业增长的重要途径。本文深入探讨了AI智能名片S2B2C商城小程序在社群去中心化管理中的应用,通过详细分析社群去中心化…

科普文:多线程如何使用CPU缓存?

一、前言 计算机的基础知识聊的比较少,但想要更好的理解多线程以及为后续多线程的介绍做铺垫,所以有必要单独开一篇来聊一下 CPU cache。 二、CPU 前面有一篇文章关于 CPU是如何进行计算 感兴趣的同学,可以先移步了解一下,不了…

连接Redis异常:JedisMovedDataException

redis.clients.jedis.exceptions.JedisMovedDataException: MOVED 5798 192.168.187.138:6379 在使用JAVA API连接redis的时候,出现了异常: 问题的原因 JAVA API实现是redis集群实现方式,而在配置文中就配置的是单结点的方式。 Moved表示使…

从人工巡检到智能防控:智慧油气田安全生产的新视角

一、背景需求 随着科技的飞速发展,视频监控技术已成为各行各业保障安全生产、提升管理效率的重要手段。特别是在油气田这一特殊领域,由于其工作环境复杂、安全风险高,传统的监控方式已难以满足实际需求。因此,基于视频监控AI智能…

NLP教程:1 词袋模型和TFIDF模型

文章目录 词袋模型TF-IDF模型词汇表模型 词袋模型 文本特征提取有两个非常重要的模型: 词集模型:单词构成的集合,集合自然每个元素都只有一个,也即词集中的每个单词都只有一个。 词袋模型:在词集的基础上如果一个单词…

springcolud学习04Ribbon

Ribbon Ribbon是一个用于构建分布式系统的开源项目,最初由Netflix开发。它是一个基于HTTP和TCP客户端负载均衡器,用于将客户端的请求分发到多个服务实例上,以提高系统的性能和可靠性。Ribbon提供了许多负载均衡算法和配置选项,可…

maven内网依赖包编译报错问题的一种解决方法

背景 外网开发时可以连接互联网,所以编译没有什么问题,但是将数据库、代码、maven仓库全部拷贝到内网,搭建内网环境之后,编译失败。 此依赖包的依赖层级图 maven镜像库配置使用拷贝到内网的本地库,配置如下&#xff…

WebRTC音视频-前言介绍

目录 效果预期 1:WebRTC相关简介 1.1:WebRTC和RTC 1.2:WebRTC前景和应用 2:WebRTC通话原理 2.1:媒体协商 2.2:网络协商 2.3:信令服务器 效果预期 1:WebRTC相关简介 1.1&…

24位动态信号采集卡8路同步音频震动信号采集IEPE采集卡USB8814

24位动态信号采集卡 音频震动信号采集USB8814实测演示 品牌:阿尔泰科技 产品概述: USB8814 是一款为测试音频和振动信号而设计的高精度数据采集卡。该板卡提供 8 路同步模拟输 入通道,24bit 分辨率,单通道采样速率zui高 204.8kSP…

4.定时器

原理 时钟源:定时器是内部时钟源(晶振),计数器是外部计时长度:对应TH TL计数器初值寄存器(高八位,低八位)对应的中断触发函数 中断源中断处理函数Timer0Timer0_Routine(void) interrupt 1Timer1Timer1_Routine(void) …

css list布局 高端玩法

这种布局方式 通常父级item 使用display:flex; 子集list使用margin-right margin-bottom撑开距离 然后得纠结最后一个子集的margin什么的 有个新思路子集使用padding <div class"video-box"><div class"video-list" v-for"item in videoLis…

系统架构设计师教程(清华第二版) 第3章 信息系统基础知识-3.3 管理信息系统(MIS)-解读

系统架构设计师教程 第3章 信息系统基础知识-3.3 管理信息系统(MIS) 3.3.1 管理信息系统的概念3.3.1.1 部件组成3.3.1.2 结构分类3.3.1.2.1 开环结构3.3.1.2.2 闭环结构3.3.1.3 金字塔结构3.3.2 管理信息系统的功能3.3.3 管理信息系统的组成3.3.3.1 销售市场子系统3.3.3.2…

前端学习(二)之HTML

一、HTML文件结构 <!DOCTYPE html> <!-- 告诉浏览器&#xff0c;这是一个HTML文件 --><html lang"en"> <!-- 根元素&#xff08;起始点&#xff0c;最外层容器&#xff09; --><head> <!-- 文档的头部&#xff08;元信息&#xff…